In [1]:

!pip install datasets
import pandas as pd
import ast  # To safely evaluate the string representation of lists
from sentence_transformers import util
from sentence_transformers import SentenceTransformer, InputExample, losses, LoggingHandler, models, SimilarityFunction
from sentence_transformers.evaluation import InformationRetrievalEvaluator
from torch.utils.data import DataLoader
import logging
import math
import os
import random
import nltk
import os
import sys
import datetime
import argparse
import json
from datasets import Dataset # <--- ADDED IMPORT


from google.colab import drive, userdata
drive.mount('/content/drive')
# Adjust path as needed
BASE_DIR = '/content/drive/My Drive/SUNY_Poly_DSA598/'


# --- Setup Logging ---
logging.basicConfig(format='%(asctime)s - %(message)s',
                    datefmt='%Y-%m-%d %H:%M:%S',
                    level=logging.INFO,
                    handlers=[LoggingHandler()])
logger = logging.getLogger(__name__)

# --- Download NLTK sentence tokenizer if needed ---
nltk.download('punkt')
nltk.download('punkt_tab')

hf_key = userdata.get('hf_key')
os.environ['HUGGINGFACE_TOKEN'] = hf_key

Collecting datasets
  Downloading datasets-3.5.0-py3-none-any.whl.metadata (19 kB)
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Collecting xxhash (from datasets)
  Downloading xxhash-3.5.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Collecting multiprocess<0.70.17 (from datasets)
  Downloading multiprocess-0.70.16-py311-none-any.whl.metadata (7.2 kB)
Collecting fsspec<=2024.12.0,>=2023.1.0 (from fsspec[http]<=2024.12.0,>=2023.1.0->datasets)
  Downloading fsspec-2024.12.0-py3-none-any.whl.metadata (11 kB)
Downloading datasets-3.5.0-py3-none-any.whl (491 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m491.2/491.2 kB[0m [31m9.9 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading dill-0.3.8-py3-none-any.whl (116 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m116.3/116.3 kB[0m [31m5.3 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading fsspec-2024.12.0-py3-none-any.wh

[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt.zip.
[nltk_data] Downloading package punkt_tab to /root/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt_tab.zip.


In [21]:
date_time_obj = datetime.datetime.now()
date_str = date_time_obj.strftime("%m-%d_%H%M")

# --- Configuration ---
# SentenceTransformer
MODEL_NAME = 'sentence-transformers/all-mpnet-base-v2'

# Training Parameters
NUM_EPOCHS = 2         # Number of training epochs (1-4 is often sufficient)
TRAIN_BATCH_SIZE = 32  # Adjust based on GPU memory (larger often better for MNRL)
LEARNING_RATE = 2e-5   # Standard learning rate for fine-tuning transformers
WARMUP_STEPS = 100     # Number of warmup steps for the learning rate scheduler
EVAL_STEPS = 64       # Evaluate performance every N steps (if dev set provided)
MAX_SAMPLES = 1024      # Limit the number of training samples (for testing)
MAX_VALID_SAMPLES = 36

TRAIN_CSV_PATH = os.path.join(BASE_DIR, 'datasets/FEVER/tabular_sets/tabular_sentEx_paper_dev_train/v1_segmented_sentIDs_n3461_04-04_002.csv')
DEV_CSV_PATH = os.path.join(BASE_DIR, 'datasets/FEVER/tabular_sets/tabular_sentEx_paper_dev_valid/v1_segmented_sentIDs_n1482_04-04_002.csv' ) # Use same for dev set for testing
OUTPUT_PATH = os.path.join(BASE_DIR, f'models/sBERT/{MODEL_NAME.replace("/", "-")}_n{MAX_SAMPLES}_{date_str}')

verbose = True

# --- Load and Prepare Data ---

def load_and_prepare_data(csv_path, is_eval_set=False, max_samples=None, verbose=False):
    """Loads data, filters, parses evidence, and creates InputExamples or evaluation queries/corpus."""
    try:
        df = pd.read_csv(csv_path)
        logger.info(f"Loaded {len(df)} rows from {csv_path}")
    except FileNotFoundError:
        logger.error(f"Error: File not found at {csv_path}")
        return None if not is_eval_set else ({}, {}, {})

    # Filter for relevant labels
    df_filtered = df[df['label'].isin(['SUPPORTS', 'REFUTES'])].copy()
    logger.info(f"Filtered to {len(df_filtered)} SUPPORTS/REFUTES claims.")

    # Limit samples if specified
    if max_samples and len(df_filtered) > max_samples:
        df_filtered = df_filtered.sample(n=max_samples, random_state=42)
        logger.info(f"Sampled down to {len(df_filtered)} examples.")

    examples = []
    queries = {}  # query_id -> query_text (for evaluation)
    corpus = {}   # doc_id -> doc_text (for evaluation)
    relevant_docs = {} # query_id -> set(doc_ids) (for evaluation)
    doc_id_counter = 0

    for index, row in df_filtered.iterrows():
        claim = str(row['claim']).strip()
        label = row['label']
        claim_id = f"claim_{row.get('id', index)}" # Use provided ID or index

        try:
            # Safely parse the 'evidence_items' string
            evidence_list = ast.literal_eval(str(row['evidence_items']))
            if not isinstance(evidence_list, list):
                raise ValueError("Parsed evidence_items is not a list")

            positive_found = False
            for evidence_item in evidence_list:
                 # Ensure evidence_item is a list/tuple with at least 2 elements
                if isinstance(evidence_item, (list, tuple)) and len(evidence_item) >= 2:
                    if verbose:
                        print(f"Processing evidence item: {evidence_item}")
                    sentence_text = str(evidence_item[0]).strip()
                    page_title = str(evidence_item[1]).strip()

                    if sentence_text and page_title:
                        positive_text = f"{sentence_text} {page_title}"
                        if verbose:
                            print(f"Positive text: {positive_text}")

                        if not is_eval_set:
                            # Create training examples (pairs for MNRL)
                            examples.append(InputExample(texts=[claim, positive_text]))
                            positive_found = True
                        else:
                            # Create evaluation data structure
                            queries[claim_id] = claim # Store claim as query

                            # Unique ID for this evidence sentence + title combination
                            doc_content = positive_text
                            # Simple way to avoid duplicates in corpus for the *same evidence*
                            # This assigns a NEW doc_id for each occurrence, which is needed for IR evaluator
                            # if claim text is unique per evaluation row.
                            current_doc_id = f"doc_{doc_id_counter}"
                            corpus[current_doc_id] = doc_content
                            doc_id_counter += 1

                            if claim_id not in relevant_docs:
                                relevant_docs[claim_id] = set()
                            relevant_docs[claim_id].add(current_doc_id) # Mark this doc as relevant for this claim
                            positive_found = True
                else:
                    logger.warning(f"Skipping malformed evidence item in row {index}: {evidence_item}")
                    pass # Keep processing other items

            # Optional: Add hard negatives (more complex)
            """
            if not is_eval_set and positive_found and 'full_text' in row:
                try:
                    full_text = str(row['full_text'])
                    page_sentences = nltk.sent_tokenize(full_text)
                    positive_sentences = {str(item[0]).strip() for item in evidence_list if isinstance(item, (list, tuple)) and len(item) > 0}

                    potential_negatives = [s.strip() for s in page_sentences if s.strip() and s.strip() not in positive_sentences]
                    if potential_negatives:
                      hard_negative_sentence = random.choice(potential_negatives)
                      # Assume negative uses the *same* page title as one of the positives
                      # This requires selecting one specific positive's title if multiple exist.
                      # Simplification: Use the title from the first valid positive item
                      first_valid_title = next((str(item[1]).strip() for item in evidence_list if isinstance(item, (list, tuple)) and len(item) >= 2 and str(item[0]).strip() and str(item[1]).strip()), None)
                      if first_valid_title:
                            hard_negative_text = f"{hard_negative_sentence} {first_valid_title}"
                            # Triplet format: examples.append(InputExample(texts=[claim, positive_text, hard_negative_text]))
                            # For MNRL, we can add as a negative pair (not directly supported in this format)
                            examples.append(InputExample(texts=[claim, hard_negative_text])) # Add as a negative pair
                except Exception as e_neg:
                    logger.warning(f"Could not generate hard negative for row {index}: {e_neg}")
            """
        except (ValueError, SyntaxError, TypeError) as e:
            logger.warning(f"Skipping row {index} due to error parsing evidence_items: {e}. Content: {row.get('evidence_items', 'N/A')}")
        except Exception as e:
             logger.error(f"Unexpected error processing row {index}: {e}")

    if is_eval_set:
        logger.info(f"Prepared evaluation data: {len(queries)} queries, {len(corpus)} corpus docs, {sum(len(v) for v in relevant_docs.values())} relevant pairs.")
        # Basic check for empty structures
        if not queries or not corpus or not relevant_docs:
             logger.warning("Evaluation data structures are empty or incomplete.")
        return queries, corpus, relevant_docs
    else:
        logger.info(f"Created {len(examples)} training pairs.")
        if not examples:
             logger.warning("No training examples were created. Check data and filtering.")
        return examples

# --- Load Data ---
logger.info("Loading training data...")
train_samples = load_and_prepare_data(TRAIN_CSV_PATH, is_eval_set=False, max_samples=MAX_SAMPLES, verbose=False) # Limit training set size if large

evaluator = None
if os.path.exists(DEV_CSV_PATH):
    logger.info("Loading development (evaluation) data...")
    dev_queries, dev_corpus, dev_relevant_docs = load_and_prepare_data(DEV_CSV_PATH, is_eval_set=True, max_samples=MAX_VALID_SAMPLES) # Limit dev set size if large
    if dev_queries and dev_corpus and dev_relevant_docs:
        evaluator = InformationRetrievalEvaluator(dev_queries, dev_corpus, dev_relevant_docs,
                                                name='fever-dev',
                                                show_progress_bar=True,
                                                write_csv=True
        )
        logger.info("InformationRetrievalEvaluator created for development set.")
    else:
         logger.warning("Could not create evaluator due to missing/empty dev data structures.")
else:
    logger.info("No development set specified or found. Skipping evaluation during training.")


# --- Model & Training Setup ---
if train_samples: # Proceed only if training data was loaded successfully
    logger.info(f"Loading pre-trained model: {MODEL_NAME}")
    # Use models.Transformer to ensure we can add pooling layer if needed, though MPNet usually has one.
    word_embedding_model = models.Transformer(MODEL_NAME)
    pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension())
    # No dense layer needed unless changing output dimensions
    model = SentenceTransformer(modules=[word_embedding_model, pooling_model])

    logger.info(f"Using MultipleNegativesRankingLoss")
    loss = losses.MultipleNegativesRankingLoss(model=model)

    # DataLoader
    # Ensure shuffling for training data
    train_dataloader = DataLoader(train_samples, shuffle=True, batch_size=TRAIN_BATCH_SIZE)
    train_loss = losses.MultipleNegativesRankingLoss(model=model)
    logger.info(f"Training batch size: {TRAIN_BATCH_SIZE}")


    # Calculate total steps and warmup steps if not fixed
    num_training_steps = int(len(train_dataloader) * NUM_EPOCHS)
    warmup_steps = math.ceil(num_training_steps * 0.1) # Alternative: 10% of total steps
    #warmup_steps = WARMUP_STEPS

    logger.info("Starting model training...")
    model.fit(train_objectives=[(train_dataloader, loss)],
              epochs=NUM_EPOCHS,
              optimizer_params={'lr': LEARNING_RATE},
              warmup_steps=warmup_steps,
              evaluator=evaluator,
              evaluation_steps=EVAL_STEPS if evaluator else 0, # Only evaluate if evaluator exists
              output_path=OUTPUT_PATH,
              checkpoint_path=os.path.join(OUTPUT_PATH, 'checkpoints'),
              checkpoint_save_steps=EVAL_STEPS * 2 if evaluator else 1000, # Save checkpoints periodically
              checkpoint_save_total_limit=3, # Keep only the last few checkpoints
              show_progress_bar=True)

    logger.info(f"Training complete. Model saved to: {OUTPUT_PATH}")

    # Optional: Save the final model in a new directory
    os.makedirs(OUTPUT_PATH, exist_ok=True)
    model.save(OUTPUT_PATH)
    if verbose:
        print(f"Final model saved to: {OUTPUT_PATH}")
    logger.info(f"Final model explicitly saved to: {OUTPUT_PATH}")

else:
    logger.error("Cannot start training because no training samples were loaded/prepared.")

Computing widget examples:   0%|          | 0/1 [00:00<?, ?example/s]

Step,Training Loss,Validation Loss,Fever-dev Cosine Accuracy@1,Fever-dev Cosine Accuracy@3,Fever-dev Cosine Accuracy@5,Fever-dev Cosine Accuracy@10,Fever-dev Cosine Precision@1,Fever-dev Cosine Precision@3,Fever-dev Cosine Precision@5,Fever-dev Cosine Precision@10,Fever-dev Cosine Recall@1,Fever-dev Cosine Recall@3,Fever-dev Cosine Recall@5,Fever-dev Cosine Recall@10,Fever-dev Cosine Ndcg@10,Fever-dev Cosine Mrr@10,Fever-dev Cosine Map@100
54,No log,No log,1.0,1.0,1.0,1.0,1.0,0.462963,0.3,0.161111,0.829475,0.960648,0.987654,1.0,0.996589,1.0,0.993056
64,No log,No log,1.0,1.0,1.0,1.0,1.0,0.462963,0.3,0.161111,0.829475,0.960648,0.987654,1.0,0.996589,1.0,0.993056
108,No log,No log,1.0,1.0,1.0,1.0,1.0,0.462963,0.3,0.161111,0.829475,0.960648,0.987654,1.0,0.996589,1.0,0.993056


Batches:   0%|          | 0/2 [00:00<?, ?it/s]

Corpus Chunks: 100%|██████████| 1/1 [00:17<00:00, 17.49s/it]


Batches:   0%|          | 0/2 [00:00<?, ?it/s]

Corpus Chunks: 100%|██████████| 1/1 [00:16<00:00, 16.61s/it]


Batches:   0%|          | 0/2 [00:00<?, ?it/s]

Corpus Chunks: 100%|██████████| 1/1 [00:19<00:00, 19.65s/it]


Final model saved to: /content/drive/My Drive/SUNY_Poly_DSA598/models/sBERT/sentence-transformers-all-mpnet-base-v2_n1024_04-20_0422
