In [None]:
!pip install transformers

In [None]:
from transformers import MarianMTModel, MarianTokenizer
import pandas as pd
import torch

# Load backtranslation model and tokenizer
def load_translation_model(src_lang, tgt_lang):
    """Load a Helsinki-NLP OPUS-MT Marian checkpoint for one language pair.

    Args:
        src_lang: Source-language code inserted into the checkpoint name (e.g. 'en').
        tgt_lang: Target-language code inserted into the checkpoint name (e.g. 'fr').

    Returns:
        A ``(model, tokenizer)`` tuple: the ``MarianMTModel`` and the matching
        ``MarianTokenizer``, both loaded from ``Helsinki-NLP/opus-mt-<src>-<tgt>``.
    """
    checkpoint = "Helsinki-NLP/opus-mt-{}-{}".format(src_lang, tgt_lang)
    # Tokenizer first, then model (same order as before).
    mt_tokenizer = MarianTokenizer.from_pretrained(checkpoint)
    mt_model = MarianMTModel.from_pretrained(checkpoint)
    return mt_model, mt_tokenizer

# Translate text in batches
def translate_text(texts, model, tokenizer, device='cpu', batch_size=8, **generate_kwargs):
    """Translate a sequence of strings with a seq2seq model, one mini-batch at a time.

    Args:
        texts: Sequence of source strings, sliced positionally (a list or a
            pandas Series both work).
        model: Seq2seq model exposing ``.to(device)``, ``.eval()`` and
            ``.generate(**inputs)`` (e.g. a ``MarianMTModel``).
        tokenizer: Tokenizer matching ``model``. It is called on each batch and
            then used to decode the generated ids.
        device: Torch device string that the model and the encoded inputs are moved to.
        batch_size: Number of texts per ``generate`` call; must be >= 1.
        **generate_kwargs: Extra keyword arguments forwarded unchanged to
            ``model.generate`` (e.g. ``num_beams``, ``do_sample``,
            ``max_new_tokens``). This is useful for more diverse augmentation.

    Returns:
        List of decoded translations, in the same order as ``texts``.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        # With a negative step, range() yields nothing and the caller would silently get [].
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    model = model.to(device)
    model.eval()  # inference: make sure dropout is disabled regardless of prior state

    translations = []
    with torch.no_grad():  # no autograd graph is needed for translation
        for start in range(0, len(texts), batch_size):
            # list(...) so the tokenizer always receives a plain list of str.
            batch = list(texts[start:start + batch_size])
            inputs = tokenizer(batch, return_tensors="pt", padding=True, truncation=True).to(device)
            outputs = model.generate(**inputs, **generate_kwargs)
            translations.extend(tokenizer.batch_decode(outputs, skip_special_tokens=True))
    return translations

# Backtranslate texts
def backtranslate_texts(texts, device='cpu', src_lang='en', pivot_lang='fr', batch_size=8):
    """Paraphrase texts by translating them to a pivot language and back.

    Args:
        texts: Iterable of source-language strings.
        device: Torch device string used for both translation passes.
        src_lang: Language code of ``texts`` (default 'en').
        pivot_lang: Intermediate language code (default 'fr'). Checkpoints
            ``Helsinki-NLP/opus-mt-<src>-<pivot>`` and
            ``Helsinki-NLP/opus-mt-<pivot>-<src>`` must both exist.
        batch_size: Batch size passed to each translation pass.

    Returns:
        List of back-translated strings, aligned one-to-one with ``texts``.
    """
    texts = list(texts)
    if not texts:
        # Nothing to paraphrase; avoid downloading/loading two translation models.
        return []

    # Step 1: translate to the pivot language.
    forward_model, forward_tokenizer = load_translation_model(src_lang, pivot_lang)
    pivot_texts = translate_text(texts, forward_model, forward_tokenizer, device, batch_size=batch_size)
    # Drop the forward model before loading the reverse one, so only one model
    # occupies the device at a time.
    del forward_model, forward_tokenizer

    # Step 2: translate back to the source language.
    backward_model, backward_tokenizer = load_translation_model(pivot_lang, src_lang)
    return translate_text(pivot_texts, backward_model, backward_tokenizer, device, batch_size=batch_size)

# Load dataset
dataset_path = '/content/labeled_data_cleaned_whole.csv'
data = pd.read_csv(dataset_path)

# Clean missing values: NaN tweets become empty strings, and every value is a str.
data['corrected_tweet'] = data['corrected_tweet'].fillna('').astype(str)

# Separate the hate speech class (class 0).
class_0 = data[data['class'] == 0]
# Skip empty/whitespace-only texts (e.g. former NaNs). Backtranslating them gives
# the model nothing to paraphrase, and whatever it emitted would be added as class 0.
source_texts = [text for text in class_0['corrected_tweet'].tolist() if text.strip()]

# Augment class 0 with backtranslation
device = 'cuda' if torch.cuda.is_available() else 'cpu'
backtranslated = backtranslate_texts(source_texts, device)

# Keep only paraphrases that are non-empty and differ from their source. An
# unchanged round-trip would merely duplicate an existing row.
augmented_texts = [
    new for old, new in zip(source_texts, backtranslated)
    if new.strip() and new.strip() != old.strip()
]

# Create a new DataFrame for augmented data. Columns of `data` other than
# 'corrected_tweet' and 'class' will be NaN for these rows after concat.
augmented_class_0 = pd.DataFrame({
    'corrected_tweet': augmented_texts,
    'class': [0] * len(augmented_texts),
})

# Combine augmented data with the original dataset. ignore_index avoids
# duplicate index labels between the two frames. Then shuffle reproducibly.
augmented_data = pd.concat([data, augmented_class_0], ignore_index=True)
augmented_data = augmented_data.sample(frac=1, random_state=42)

# Save the augmented dataset
augmented_dataset_path = '/content/augmented_dataset.csv'
augmented_data.to_csv(augmented_dataset_path, index=False)
print(f"Added {len(augmented_class_0)} backtranslated class-0 rows")
print(f"Augmented dataset saved to {augmented_dataset_path}")
