In [12]:
# Check if GPU is available (highly recommended for faster training)
import torch
if torch.cuda.is_available():
    device = torch.device("cuda")
    print(f"Using GPU: {torch.cuda.get_device_name(0)}")
else:
    device = torch.device("cpu")
    print("Using CPU. Training might be slow.")

# 1. Install necessary libraries
!pip install -qq transformers datasets accelerate scikit-learn pandas tabulate groq

# Imports
from datasets import load_dataset, DatasetDict, Dataset
from transformers import AutoTokenizer, AutoModelForSequenceClassification, Trainer, TrainingArguments
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score
import pandas as pd
import numpy as np
from tabulate import tabulate # Import tabulate for pretty printing

# Set a seed for reproducibility
import random
random.seed(42)
np.random.seed(42)
torch.manual_seed(42)
torch.cuda.manual_seed_all(42)

Using CPU. Training might be slow.


In [13]:
ds = load_dataset("infinite-dataset-hub/CorporateMailCategorization")

# Convert to pandas DataFrame for easier null handling
df = ds["train"].to_pandas()

In [14]:
initial_rows = len(df)
df_cleaned = df.dropna(subset=['label']).copy()
rows_dropped = initial_rows - len(df_cleaned)
print(f"\nDropped {rows_dropped} rows with null 'label' values.")
print(f"Remaining rows after dropping nulls: {len(df_cleaned)}")
df_cleaned


Dropped 10 rows with null 'label' values.
Remaining rows after dropping nulls: 90


In [16]:
is_duplicate_counts = df_cleaned.duplicated().value_counts()
print(is_duplicate_counts)

False    90
Name: count, dtype: int64


In [17]:
import matplotlib.pyplot as plt
import pandas as pd
from datasets import load_dataset

label_counts = df_cleaned['label'].value_counts()
label_counts

Unnamed: 0_level_0,count
label,Unnamed: 1_level_1
Employee Feedback,21
Merger Announcement,21
Sustainability Initiative,21
Financial Report,12
Product Launch,6
Financial Projections,1
Financial Performance,1
Preliminary Financial Report,1
Audit Request,1
Financial Analysis,1


In [19]:
# Important: Map string labels to integers for the model
# Get unique labels from the cleaned dataset
unique_labels = sorted(df_cleaned['label'].unique().tolist())
label_to_id = {label: i for i, label in enumerate(unique_labels)}
id_to_label = {i: label for i, label in enumerate(unique_labels)}

print(f"\nDetected labels and their mappings: {label_to_id}")
num_labels = len(unique_labels)
print(f"Number of unique labels: {num_labels}")

# Apply label mapping to the cleaned dataset
def map_labels_to_ids(example):
    example['label'] = label_to_id[example['label']]
    return example

# Convert cleaned DataFrame to Hugging Face Dataset
full_labeled_ds = Dataset.from_pandas(df_cleaned).map(map_labels_to_ids)


Detected labels and their mappings: {'Audit Request': 0, 'Budget Report': 1, 'Employee Feedback': 2, 'Financial Analysis': 3, 'Financial Health': 4, 'Financial Performance': 5, 'Financial Projections': 6, 'Financial Report': 7, 'Merger Announcement': 8, 'Preliminary Financial Report': 9, 'Product Launch': 10, 'Profit Analysis': 11, 'Summary Financial Report': 12, 'Sustainability Initiative': 13}
Number of unique labels: 14


Map:   0%|          | 0/90 [00:00<?, ? examples/s]

In [20]:
import random
import os
import json
from groq import Groq

from google.colab import userdata
os.environ["GROQ_API_KEY"] = userdata.get('GROQ_API_KEY')

class SyntheticTextGenerator:
    def __init__(self, id_to_label_map):
        self.client = Groq()
        self.groq_model = "llama3-8b-8192"
        self.id_to_label = id_to_label_map # Store the mapping for better prompts

    def generate_similar_texts(self, reference_texts: list, num_to_generate: int, target_label_id: int) -> list:
        """
        Generates new text examples similar to reference_texts for a given label,
        expecting JSON output from the LLM.
        """
        target_label_name = self.id_to_label.get(target_label_id, "Unknown Category")

        if not reference_texts:
            print(f"Warning: No reference texts provided for label ID {target_label_id} ({target_label_name}). Cannot generate examples.")
            return []

        # Prepare reference examples for the prompt
        reference_json_examples = []
        for text in reference_texts:
            # Ensure reference examples are also clean JSON lines
            reference_json_examples.append(json.dumps({"text": text, "label": target_label_name}))
        reference_str = ",\n".join(reference_json_examples)

        # Adjusted system prompt for more robust JSON output instructions
        system_prompt = (
            "You are an AI assistant specialized in generating realistic and diverse corporate email snippets "
            "for data augmentation. Your output MUST be a strict JSON array where each element is a JSON object with 'text' and 'label' keys. "
            "It must start with `[` and end with `]`. "
            "Do NOT include any additional text, explanations, or formatting outside of the JSON array. "
            "Do NOT wrap the array in any other JSON object, like {\"data\": [...]}. Provide ONLY the JSON array. "
            "Ensure the generated emails are plausible and distinct from the references."
            "Always wrap your JSON output in triple backticks with 'json' language specifier, e.g., ```json [...]```."
        )

        user_prompt = f"""
        Generate exactly {num_to_generate} new and distinct corporate email snippets.
        Each snippet must be categorized as '{target_label_name}'.

        Here are some existing examples for reference, categorized as '{target_label_name}'.
        These are provided as a JSON array of objects, with 'text' and 'label' keys:
        ```json
        [
        {reference_str}
        ]
        ```

        Produce your output as a single JSON array, directly.
        Each element in the array must be a JSON object with two keys:
        1. "text": The generated email snippet.
        2. "label": The category, which must be '{target_label_name}'.

        Output ONLY the JSON array, and wrap it in triple backticks with the 'json' language specifier (e.g., ```json [...]``` ).
        For example:
        ```json
        [
          {{"text": "Generated email snippet 1", "label": "{target_label_name}"}},
          {{"text": "Generated email snippet 2", "label": "{target_label_name}"}}
        ]
        ```
        """

        json_output_str = None # Initialize json_output_str

        try:
            chat_completion = self.client.chat.completions.create(
                messages=[
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": user_prompt}
                ],
                model=self.groq_model,
                temperature=0.8,
                max_tokens=int(num_to_generate * 100)
            )

            # Extract content and try to clean it
            json_output_str = chat_completion.choices[0].message.content.strip()
            # print(json_output_str)

            if json_output_str.startswith("```") :
                cleaned_json_str = json_output_str[len("```"):].strip()
                print('check 0')
            if cleaned_json_str.startswith("json") :
                cleaned_json_str = cleaned_json_str[len("json"):].strip()
                print('check 1')
            if cleaned_json_str.endswith("```") :
                cleaned_json_str = cleaned_json_str[:-len("```")].strip()
                print('check 2')
            else:
                # If it didn't use backticks, try parsing directly as the model might still try to output raw JSON
                # cleaned_json_str = json_output_str
                print('check 3')
            # print('out', cleaned_json_str)

            parsed_raw = None
            try:
                parsed_raw = json.loads(cleaned_json_str)
            except json.JSONDecodeError as de:
                print(f"Secondary JSON parsing failed: {de}")
                raise json.JSONDecodeError("Failed to parse JSON after initial cleaning attempts.", cleaned_json_str, 0)


            # After parsing (which might have been a dict or a list)
            generated_data_list = []
            if isinstance(parsed_raw, dict) and 'data' in parsed_raw and isinstance(parsed_raw['data'], list):
                generated_data_list = parsed_raw['data']
                print("Note: LLM output was wrapped in 'data' key, extracted successfully.")
            elif isinstance(parsed_raw, list):
                generated_data_list = parsed_raw
            else:
                raise ValueError(f"LLM output is not a direct JSON array or wrapped in a 'data' key as expected. Type: {type(parsed_raw)}")

            # Validate the structure of each item
            validated_examples = []
            for item in generated_data_list:
                if isinstance(item, dict) and 'text' in item and 'label' in item:
                    if item['label'] == target_label_name:
                        validated_examples.append(item)
                    else:
                        print(f"Warning: Generated label '{item['label']}' does not match target '{target_label_name}'. Skipping.")
                else:
                    print(f"Warning: Invalid item structure received from LLM: {item}. Skipping.")
            return validated_examples

        except json.JSONDecodeError as e:
            print(f"Error parsing JSON from Groq API for {target_label_name}: {e}")
            print(f"Raw LLM output (for debugging): {json_output_str}")
            return []
        except Exception as e:
            print(f"Error calling Groq API or processing response for {target_label_name}: {e}")
            print(f"Raw LLM output (if available): {json_output_str}")
            return []


# Instantiate the synthetic data generator
generator = SyntheticTextGenerator(id_to_label)

label_counts = df_cleaned['label'].value_counts()
under_represented_label_names = label_counts[label_counts < 15].index.tolist()

augmented_data = []
temp = ['Financial Report']

for label_name in under_represented_label_names:
    print(f"\nAugmenting label: '{label_name}'")
    current_count = label_counts[label_name]

    # Get existing examples for reference
    reference_df = df_cleaned[df_cleaned['label'] == label_name]

    num_references = min(5, len(reference_df)) # Use up to 5 references
    reference_texts = random.sample(reference_df['text'].tolist(), num_references)

    # Determine how many new examples are needed to reach TARGET
    num_to_generate = 15 - current_count

    print(f"  Need to generate {num_to_generate} new examples for '{label_name}'.")

    label_id = label_to_id[label_name]
    new_generated_examples = generator.generate_similar_texts(reference_texts, num_to_generate, label_id)

    augmented_data.extend(new_generated_examples)
    print(f"  Generated {len(new_generated_examples)} new examples for '{label_name}'.")

# Now, 'augmented_data' contains a list of dictionaries, each with 'text' and 'label'
print(f"\n--- Overall Generated Synthetic Data ---")
print(f"Total new synthetic examples generated across all under-represented labels: {len(augmented_data)}")
if augmented_data:
    # Print a sample of the generated data to verify format
    print("\nSample of generated data (first 3):")
    for i, ex in enumerate(augmented_data[:3]):
        print(f"  {i+1}. Text: '{ex['text'][:70]}...', Label: '{ex['label']}'")


Augmenting label: 'Financial Report'
  Need to generate 3 new examples for 'Financial Report'.
check 0
check 1
check 3
  Generated 3 new examples for 'Financial Report'.

Augmenting label: 'Product Launch'
  Need to generate 9 new examples for 'Product Launch'.
check 0
check 1
check 3
  Generated 9 new examples for 'Product Launch'.

Augmenting label: 'Financial Projections'
  Need to generate 14 new examples for 'Financial Projections'.
check 0
check 1
check 3
  Generated 14 new examples for 'Financial Projections'.

Augmenting label: 'Financial Performance'
  Need to generate 14 new examples for 'Financial Performance'.
check 0
check 1
check 3
  Generated 14 new examples for 'Financial Performance'.

Augmenting label: 'Preliminary Financial Report'
  Need to generate 14 new examples for 'Preliminary Financial Report'.
check 0
check 1
check 3
  Generated 14 new examples for 'Preliminary Financial Report'.

Augmenting label: 'Audit Request'
  Need to generate 14 new examples for 'Audi

In [21]:
# for content  in augmented_data :
#     print(f"email : {content['text']} & label : {content['label']}")

In [22]:
if augmented_data:
    augmented_df = pd.DataFrame(augmented_data)
    df_augmented = pd.concat([df_cleaned, augmented_df], ignore_index=True)
    print(f"Total rows after augmentation: {len(df_augmented)}")
    df_cleaned = df_augmented.copy()
else:
    print("\nNo data augmentation performed as no under-represented labels were found or generation failed.")

Total rows after augmentation: 228


In [23]:
def map_labels_to_ids_for_augmented(row):
    # Check if the 'label' in the row is a string before attempting to map
    if isinstance(row['label'], str):
        row['label'] = label_to_id[row['label']]
    return row

# Convert df_cleaned (now potentially df_augmented) to a Hugging Face Dataset
# Use .apply(..., axis=1) to apply the function row-wise
df_cleaned = df_cleaned.apply(map_labels_to_ids_for_augmented, axis=1)

# Now, convert the DataFrame to a Hugging Face Dataset
from datasets import Dataset

# Assuming your DataFrame has a 'text' and 'label' column
full_labeled_ds = Dataset.from_pandas(df_cleaned)

print("Final Dataset object created successfully:")
print(full_labeled_ds)

Final Dataset object created successfully:
Dataset({
    features: ['idx', 'text', 'label'],
    num_rows: 228
})


In [24]:
full_labeled_ds=full_labeled_ds.remove_columns(["idx"])

In [25]:
model_name = "distilbert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=num_labels)

# Move model to GPU if available
model.to(device)


tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/483 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/268M [00:00<?, ?B/s]

Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


DistilBertForSequenceClassification(
  (distilbert): DistilBertModel(
    (embeddings): Embeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (transformer): Transformer(
      (layer): ModuleList(
        (0-5): 6 x TransformerBlock(
          (attention): DistilBertSdpaAttention(
            (dropout): Dropout(p=0.1, inplace=False)
            (q_lin): Linear(in_features=768, out_features=768, bias=True)
            (k_lin): Linear(in_features=768, out_features=768, bias=True)
            (v_lin): Linear(in_features=768, out_features=768, bias=True)
            (out_lin): Linear(in_features=768, out_features=768, bias=True)
          )
          (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (ffn): FFN(
            (dropout): Dropout(p=0.1, inplace=False)


In [26]:
def compute_metrics(p):
    predictions = np.argmax(p.predictions, axis=1)
    labels = np.array(p.label_ids).flatten()

    accuracy = accuracy_score(labels, predictions)
    f1_macro = f1_score(labels, predictions, average='macro', zero_division=0)
    precision_macro = precision_score(labels, predictions, average='macro', zero_division=0)
    recall_macro = recall_score(labels, predictions, average='macro', zero_division=0)


    metrics = {
        'accuracy': accuracy,
        'f1_weighted': f1_macro,
        'precision_weighted': precision_macro,
        'recall_weighted': recall_macro,
    }
    return metrics


In [27]:
train_val_test_split = full_labeled_ds.train_test_split(test_size=0.1, seed=42)
train_val_ds = train_val_test_split["train"]
test_ds_final = train_val_test_split["test"] # This is our final, labeled test set

train_val_split = train_val_ds.train_test_split(test_size=0.11, seed=42)
train_ds_split = train_val_split["train"]
eval_ds_split = train_val_split["test"] # This is our validation set

print(f"\nDataset Splits:")
print(f"  Training samples: {len(train_ds_split)}")
print(f"  Validation samples: {len(eval_ds_split)}")
print(f"  Final Test samples: {len(test_ds_final)}")

print("\nTrain Dataset Split Structure:", train_ds_split)
print("Validation Dataset Split Structure:", eval_ds_split)
print("Final Test Dataset Structure:", test_ds_final)


Dataset Splits:
  Training samples: 182
  Validation samples: 23
  Final Test samples: 23

Train Dataset Split Structure: Dataset({
    features: ['text', 'label'],
    num_rows: 182
})
Validation Dataset Split Structure: Dataset({
    features: ['text', 'label'],
    num_rows: 23
})
Final Test Dataset Structure: Dataset({
    features: ['text', 'label'],
    num_rows: 23
})


In [28]:
def tokenize_function(examples):
    return tokenizer(examples["text"], padding="max_length", truncation=True, max_length=128)

tokenized_train_ds = train_ds_split.map(tokenize_function, batched=True)
tokenized_eval_ds = eval_ds_split.map(tokenize_function, batched=True)
tokenized_test_ds_final = test_ds_final.map(tokenize_function, batched=True) # Tokenize final test set

# Remove original text and idx columns as they are no longer needed for training/evaluation
tokenized_train_ds = tokenized_train_ds.remove_columns(["text"])
tokenized_eval_ds = tokenized_eval_ds.remove_columns(["text"])
tokenized_test_ds_final = tokenized_test_ds_final.remove_columns(["text"]) # Keep label for final test

print(tokenized_train_ds)
# print(len(tokenized_eval_ds))
# print(len(tokenized_test_ds_final))

Map:   0%|          | 0/182 [00:00<?, ? examples/s]

Map:   0%|          | 0/23 [00:00<?, ? examples/s]

Map:   0%|          | 0/23 [00:00<?, ? examples/s]

Dataset({
    features: ['label', 'input_ids', 'attention_mask'],
    num_rows: 182
})


In [32]:
training_args = TrainingArguments(
    output_dir="./results_mail_category", # Directory for logs and checkpoints
    num_train_epochs=10,                  # Number of training epochs
    per_device_train_batch_size=16,      # Batch size for training
    per_device_eval_batch_size=16,       # Batch size for evaluation
    warmup_steps=10,                     # Number of warmup steps for learning rate scheduler
    weight_decay=0.01,                   # Strength of weight decay
    logging_dir="./logs_mail_category",  # Directory for storing logs
    logging_strategy="epoch",            # Log metrics at the end of each epoch
    save_strategy="epoch",               # Save model at the end of each epoch
    eval_strategy="epoch",               # Evaluate at the end of each epoch
    load_best_model_at_end=True,         # Load the best model at the end of training
    metric_for_best_model="f1_weighted", # Metric to use to compare models
    report_to="none",                    # Don't report to any online services
)


trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_train_ds,
    eval_dataset=tokenized_eval_ds,
    compute_metrics=compute_metrics,
    tokenizer=tokenizer
)

# Start training!
trainer.train()

print("\n--- 7. Evaluation Metrics on Validation Set ---")
eval_results = trainer.evaluate()
print("\nEvaluation Metrics on Validation Set (used during training):")
eval_df = pd.DataFrame([eval_results]).transpose()
eval_df.columns = ['Value']
print(tabulate(eval_df, headers='keys', tablefmt='grid', floatfmt=".4f"))

  trainer = Trainer(


Epoch,Training Loss,Validation Loss,Accuracy,F1 Weighted,Precision Weighted,Recall Weighted
1,1.4633,1.51483,0.826087,0.754545,0.818182,0.772727
2,1.0585,1.221233,0.826087,0.705556,0.763889,0.708333
3,0.7551,0.933833,0.826087,0.705556,0.763889,0.708333
4,0.5397,0.741602,0.782609,0.635897,0.705128,0.628205
5,0.3763,0.693858,0.869565,0.75,0.791667,0.75
6,0.2841,0.678949,0.782609,0.653846,0.730769,0.641026
7,0.2046,0.603926,0.782609,0.653846,0.730769,0.641026
8,0.1718,0.516514,0.826087,0.676923,0.730769,0.666667
9,0.1426,0.536358,0.826087,0.676923,0.730769,0.666667
10,0.1267,0.498098,0.826087,0.676923,0.730769,0.666667



--- 7. Evaluation Metrics on Validation Set ---



Evaluation Metrics on Validation Set (used during training):
+-------------------------+---------+
|                         |   Value |
| eval_loss               |  1.5148 |
+-------------------------+---------+
| eval_accuracy           |  0.8261 |
+-------------------------+---------+
| eval_f1_weighted        |  0.7545 |
+-------------------------+---------+
| eval_precision_weighted |  0.8182 |
+-------------------------+---------+
| eval_recall_weighted    |  0.7727 |
+-------------------------+---------+
| eval_runtime            |  4.4463 |
+-------------------------+---------+
| eval_samples_per_second |  5.1730 |
+-------------------------+---------+
| eval_steps_per_second   |  0.4500 |
+-------------------------+---------+
| epoch                   | 10.0000 |
+-------------------------+---------+


In [33]:
print("\n--- Saving the Fine-tuned Model ---")
save_path = "./mail_category"
trainer.save_model(save_path)
tokenizer.save_pretrained(save_path)
print(f"\nModel and tokenizer saved to: {save_path}")


--- Saving the Fine-tuned Model ---

Model and tokenizer saved to: ./mail_category


In [34]:
print("\n--- Performing Inference and Evaluation on Final Test Set ---")
# Load the saved model to ensure we are using the fine-tuned one
loaded_tokenizer = AutoTokenizer.from_pretrained(save_path)
loaded_model = AutoModelForSequenceClassification.from_pretrained(save_path, num_labels=num_labels)
loaded_model.to(device) # Move to GPU if available

# Use the loaded model for prediction on the FINAL test set
final_test_trainer = Trainer(model=loaded_model, tokenizer=loaded_tokenizer)

# Predict on the tokenized final test dataset
predictions_final_test = final_test_trainer.predict(tokenized_test_ds_final)

# Get predicted class IDs and confidence scores
predicted_ids_final_test = np.argmax(predictions_final_test.predictions, axis=1)
confidence_scores_final_test = np.max(torch.softmax(torch.tensor(predictions_final_test.predictions), dim=1).numpy(), axis=1)

# Get actual labels from the original (non-tokenized) test_ds_final
# Note: test_ds_final['label'] should contain the numerical label IDs
actual_labels_final_test = [id_to_label[label_id] for label_id in test_ds_final['label']]

# Map predicted IDs back to original labels
predicted_labels_final_test = [id_to_label[id_val] for id_val in predicted_ids_final_test]

# Create a DataFrame for tabular output of actual vs. predicted
final_results_df = pd.DataFrame({
    'text': test_ds_final['text'],
    'actual_label': actual_labels_final_test,
    'predicted_label': predicted_labels_final_test,
    'confidence': confidence_scores_final_test
}).reset_index(drop=True)

print("\n--- Actual vs. Predicted Categories for Final Test Samples ---")
print(tabulate(final_results_df, headers='keys', tablefmt='grid', floatfmt=".4f"))

# Calculate and print metrics for the final test set
# The compute_metrics function will now return accuracy, f1_macro, precision_macro, recall_macro
final_test_metrics = compute_metrics(predictions_final_test)
print("\nEvaluation Metrics on Final Test Set:")
final_test_metrics_df = pd.DataFrame([final_test_metrics]).transpose()
final_test_metrics_df.columns = ['Value']
print(tabulate(final_test_metrics_df, headers='keys', tablefmt='grid', floatfmt=".4f"))

print("\n--- Fine-tuning process complete! ---")


--- Performing Inference and Evaluation on Final Test Set ---


  final_test_trainer = Trainer(model=loaded_model, tokenizer=loaded_tokenizer)



--- Actual vs. Predicted Categories for Final Test Samples ---
+----+--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+------------------------------+------------------------------+--------------+
|    | text                                                                                                                                                                                                                                                               | actual_label                 | predicted_label              |   confidence |
|  0 | Our latest budget report is now available for review. It includes a detailed breakdown of our current financial situation and future projections. - Michael Lee                                                                                