In [10]:
import os
import sys
import logging
import json
import random
import argparse
from datetime import datetime
from tqdm import tqdm
from transformers import pipeline

import torch
from torch.utils.data import DataLoader
from torch.optim import AdamW
import wandb

from transformers import GPT2Tokenizer, GPT2Config
sys.path.append(os.path.join('/gpfs/home/bsk18/factual-bias-mitigation/scripts/summarization/', '../../'))

# Assume these are defined in your project based on previous code
from src.contrastive_learning.contrastive_gpt2_architecture import ContrastiveGPT2Model  # Our adjusted model
from src.contrastive_learning.contrastive_dataset import ContrastiveTranslationDataset, contrastive_collate_fn  # Our dataset and collate function
from src.contrastive_learning.contrastive_loss import ContrastiveLoss  # Our custom loss function
from src.factuality_detector import FactualityDetector
from src.contrastive_learning.train_utils import (
    evaluate_toxicity  # Function to compute toxicity scores,
)
from src.contrastive_learning.utils import( 
    calculate_original_factuality,
    calculate_original_toxicity,
    load_indices,
    combine_data,
    read_files,
)

from sacrebleu import corpus_bleu
from rouge_score import rouge_scorer


# Set up logging
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

# Create a file handler
log_dir = 'logs'
os.makedirs(log_dir, exist_ok=True)
log_file = os.path.join(log_dir, f"training_{datetime.now().strftime('%Y%m%d_%H%M%S')}.log")

file_handler = logging.FileHandler(log_file)
file_handler.setLevel(logging.INFO)

# Create a logging format
formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
file_handler.setFormatter(formatter)

# Console handler
console_handler = logging.StreamHandler()
console_handler.setLevel(logging.INFO)
console_handler.setFormatter(formatter)

# Add handlers to the logger
logger.addHandler(file_handler)
logger.addHandler(console_handler)



In [11]:
# Process 'all' option
all_pos_options = ['original', 'bt']
all_neg_options = ['original', 'entity_swap', 'low_confidence']

pos_data = ['original', 'pos_bt']
neg_data = ['original']

# Ensure 'original' is always included
if 'original' not in pos_data:
    pos_data.insert(0, 'original')
if 'original' not in neg_data:
    neg_data.insert(0, 'original')

# Define data paths based on arguments
data_dir = '/gpfs/home/bsk18/factual-bias-mitigation/data/tldr'  # Update with your actual data directory
pos_data_paths = []
neg_data_paths = []

for opt in pos_data:
    if opt == 'original':
        pos_data_paths.append(os.path.join(data_dir, 'pos'))
    else:
        pos_data_paths.append(os.path.join(data_dir, f'pos_{opt}'))
for opt in neg_data:
    if opt == 'original':
        neg_data_paths.append(os.path.join(data_dir, 'neg'))
    else:
        neg_data_paths.append(os.path.join(data_dir, f'neg_{opt}'))

# Combine data from all specified paths
pos_data = combine_data([read_files(path, logger) for path in pos_data_paths])
neg_data = combine_data([read_files(path, logger) for path in neg_data_paths])

val_pos_data = read_files('/gpfs/home/bsk18/factual-bias-mitigation/data/tldr/validation/pos', logger, splits=['validation'])
val_neg_data = read_files('/gpfs/home/bsk18/factual-bias-mitigation/data/tldr/validation/neg', logger, splits=['validation'])

# Load index mappings
index_dir = '/gpfs/home/bsk18/factual-bias-mitigation/data/tldr/indices/pos_all_neg_original/'  # Update with the directory where index files are stored
train_pos_indices = load_indices(os.path.join(index_dir, 'train.positive.index'))
train_neg_indices = load_indices(os.path.join(index_dir, 'train.negative.index'))
val_pos_indices = load_indices(os.path.join(index_dir, 'validation.positive.index'))
val_neg_indices = load_indices(os.path.join(index_dir, 'validation.negative.index'))

# Initialize the factuality detector
factuality_detector = FactualityDetector("buseskorkmaz/factual-bias-mitigation-models")
logger.info("Initialized FactualityDetector")

model_name = 'gpt2'

# Initialize the model
if model_name == "gpt2":
    config = GPT2Config.from_pretrained('gpt2')
    config.output_hidden_states = True
    model = ContrastiveGPT2Model(config)
else:
    # Load other models as needed
    pass

tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
tokenizer.padding_side = 'left'
tokenizer.pad_token = tokenizer.eos_token

# Prepare datasets
train_dataset = ContrastiveTranslationDataset(
    src_texts=pos_data['train']['source'],
    pos_texts=pos_data['train']['target'],
    neg_texts=neg_data['train']['target'],
    pos_indices=train_pos_indices,
    neg_indices=train_neg_indices,
    tokenizer=tokenizer,
    max_length=512,
    max_neg_samples=5,
    cl_seed=0
)

val_dataset = ContrastiveTranslationDataset(
    src_texts=val_pos_data['validation']['source'],
    pos_texts=val_pos_data['validation']['target'],
    neg_texts=val_neg_data['validation']['target'],
    pos_indices=val_pos_indices,
    neg_indices=val_neg_indices,
    tokenizer=tokenizer,
    max_length=512,
    max_neg_samples=5,
    cl_seed=0
)

# Create DataLoaders
train_loader = DataLoader(
    train_dataset, batch_size=2, shuffle=True, collate_fn=contrastive_collate_fn
)
val_loader = DataLoader(
    val_dataset, batch_size=2, collate_fn=contrastive_collate_fn
)


2024-11-02 14:43:56,850 - INFO - Loaded data from /gpfs/home/bsk18/factual-bias-mitigation/data/tldr/pos
2024-11-02 14:43:56,852 - INFO - Couldn't find /gpfs/home/bsk18/factual-bias-mitigation/data/tldr/pos_pos_bt
2024-11-02 14:43:56,888 - INFO - Loaded data from /gpfs/home/bsk18/factual-bias-mitigation/data/tldr/neg
2024-11-02 14:43:56,891 - INFO - Loaded data from /gpfs/home/bsk18/factual-bias-mitigation/data/tldr/validation/pos
2024-11-02 14:43:56,899 - INFO - Loaded data from /gpfs/home/bsk18/factual-bias-mitigation/data/tldr/validation/neg


skipping validation, couldn't find
skipping train, couldn't find
skipping validation, couldn't find


KeyboardInterrupt: 