In [None]:
import torch

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix, classification_report

from datasets import Dataset
from transformers import AutoTokenizer, AutoModelForSequenceClassification, TrainingArguments, Trainer

In [None]:
path = 'cat_emails_v2(in).csv'


def _clean_email_text(parts):
    """Join the raw pieces of one email body and remove every double-quote character.

    Equivalent to the previous replace chain (drop triple quotes, collapse doubled
    quotes, drop single quotes): its last step deleted all remaining quotes anyway.
    NOTE(review): quotes that belong to the email text itself are removed as well.
    """
    return "".join(parts).replace('"', '').strip()


def parse_email_records(lines):
    """Parse (category, email_text) tuples from the lines of the quoted email export.

    A record starts on a line whose first non-blank character is a double quote:
    the text between that quote and the first comma is the category, and the rest
    of the line starts the email body. The record ends on the first line -- the
    opening line included -- whose stripped text ends with three double quotes.
    Lines between records that do not start with a quote, and opening lines with
    no comma, are skipped.

    Args:
        lines: iterable of text lines (header already removed), newlines kept.

    Returns:
        List of (category, cleaned_email_text) tuples in file order.
    """
    import warnings

    records = []
    category = None
    email_parts = None  # None while between records

    for line in lines:
        if email_parts is None:
            opening = line.lstrip()
            if not opening.startswith('"'):
                continue  # blank line or stray text between records
            head = opening[1:]
            comma = head.find(',')
            if comma == -1:
                continue  # no category separator: not a record start
            category = head[:comma].strip()
            # Keep the line's trailing newline so the first body line is not glued
            # onto the second one (stripping the whole line used to drop it).
            email_parts = [head[comma + 1:]]
        else:
            email_parts.append(line)

        # Checked for the opening line too: a one-line email would otherwise stay
        # open and swallow the following record(s).
        if line.strip().endswith('"""'):
            records.append((category, _clean_email_text(email_parts)))
            category = None
            email_parts = None

    if email_parts is not None:
        warnings.warn(f'Dropped unterminated final record (category {category!r}).')

    return records


with open(path, 'r', encoding="utf-8", errors="replace") as f:
    raw_lines = f.readlines()

# The first line of the export is the CSV header.
records = parse_email_records(raw_lines[1:])

data = pd.DataFrame(records, columns = ['category', 'email_text'])

data.to_csv('clean_emails.csv', index = False)

print('Parsed emails: ', len(data))

In [None]:
# Inspect the parsed emails: one row per record with 'category' and 'email_text'.
data

In [None]:
# Number of distinct categories, i.e. the number of classes the model is trained on.
data['category'].nunique()

In [None]:
# Integer class ids; codes follow the sorted order of the distinct categories.
data['label'] = pd.Categorical(data['category']).codes

# Hold out a fixed, class-stratified test set of 500 emails.
train_data, test_data = train_test_split(
    data,
    stratify = data['label'],
    test_size = 500,
    random_state = 42,
)
train_dataset, test_dataset = map(Dataset.from_pandas, (train_data, test_data))

model_name = 'distilbert-base-german-cased'
tokenizer = AutoTokenizer.from_pretrained(model_name)

def tokenize(batch):
    """Tokenize a batch of emails into fixed-length model inputs.

    Meant for ``Dataset.map(..., batched=True)``: ``batch['email_text']`` holds the
    batch's email strings. Every sequence is truncated or padded to exactly 256
    tokens.
    """
    texts = batch['email_text']
    return tokenizer(
        texts,
        max_length = 256,
        padding = 'max_length',
        truncation = True,
    )

# Tokenize both splits (train first, then test), then expose only the columns the
# model consumes, as torch tensors.
train_dataset, test_dataset = (
    split_ds.map(tokenize, batched = True) for split_ds in (train_dataset, test_dataset)
)

model_columns = ['input_ids', 'attention_mask', 'label']
for split_ds in (train_dataset, test_dataset):
    split_ds.set_format(type = 'torch', columns = model_columns)

# Class names stored in the model config, so the saved classifier reports
# categories instead of LABEL_0, LABEL_1, ... Ids follow the same sorted category
# order that produced data['label'].
category_names = data['category'].astype('category').cat.categories
id2label = {code: name for code, name in enumerate(category_names)}
label2id = {name: code for code, name in id2label.items()}

model = AutoModelForSequenceClassification.from_pretrained(
    model_name,
    num_labels = len(id2label),
    id2label = id2label,
    label2id = label2id,
)

# Fall back to CPU so the notebook also runs without a GPU; a hard-coded
# torch.device('cuda') makes model.to(device) fail there.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

# NOTE(review): test_dataset is both the eval set used to pick the best checkpoint
# (load_best_model_at_end, ranked by eval loss by default) and the set the report
# below is computed on, so those test metrics are optimistically biased. A separate
# validation split carved out of train_data would give an unbiased test score.
args = TrainingArguments(
    output_dir = 'base_email_model',
    learning_rate = 2e-5,
    per_device_train_batch_size = 8,
    per_device_eval_batch_size = 8,
    num_train_epochs = 10,
    weight_decay = 0.01,
    logging_steps = 50,
    save_strategy = 'epoch',
    eval_strategy = 'epoch',
    # Without a limit every epoch leaves a full checkpoint (weights + optimizer
    # state) on disk; with load_best_model_at_end the best one is always retained.
    save_total_limit = 2,
    load_best_model_at_end = True,
    report_to = 'none'
)

trainer = Trainer(
    model = model,
    args = args,
    train_dataset = train_dataset,
    eval_dataset = test_dataset,
)

trainer.train()
trainer.save_model('base_email_classifier')
# The Trainer was given no tokenizer, so save_model writes only the model files;
# store the tokenizer next to them so the directory can be reloaded on its own.
tokenizer.save_pretrained('base_email_classifier')

In [None]:
predictions = trainer.predict(test_dataset)
preds = np.argmax(predictions.predictions, axis = 1)
labels = predictions.label_ids

# Same sorted category order that produced the integer codes in data['label'].
label_to_category = dict(enumerate(data['category'].astype('category').cat.categories))

# List every class id explicitly. A class with no test samples and no predictions
# would otherwise be left out, the number of target_names would no longer match,
# and classification_report would raise ValueError.
print(classification_report(
    labels,
    preds,
    labels = list(label_to_category),
    target_names = list(label_to_category.values()),
    zero_division = 0,
))

In [None]:
# Convert the test split to pandas once. Assumes trainer.predict keeps
# test_dataset's row order, so rows line up with `preds`.
test_frame = test_dataset.to_pandas()
test_texts = test_frame['email_text'].tolist()
test_true_labels = test_frame['label'].tolist()

# Misclassified test emails as readable category names.
errors = [
    {'true': label_to_category[true], 'pred': label_to_category[pred]}
    for true, pred in zip(test_true_labels, preds)
    if true != pred
]

# Fixed columns so an error-free run still yields a frame with 'true'/'pred'.
errors_data = pd.DataFrame(errors, columns = ['true', 'pred'])

print('Sample of model mistakes (top 20)')
print(errors_data.head(20))

In [None]:
class_ids = list(label_to_category)
class_names = list(label_to_category.values())

fig, ax = plt.subplots(figsize = (18, 14))
sns.heatmap(
    # labels= keeps the matrix square over all classes, aligned with the tick labels.
    confusion_matrix(labels, preds, labels = class_ids),
    annot = True,
    fmt = 'd',  # integer counts; the default '.2g' prints counts >= 100 as 1e+02
    cmap = 'Blues',
    xticklabels = class_names,
    yticklabels = class_names,
    ax = ax,
)
ax.set_title('Confusion Matrix')
ax.set_xlabel('Predicted')
ax.set_ylabel('True')
ax.tick_params(axis = 'x', labelrotation = 90)
ax.tick_params(axis = 'y', labelrotation = 0)
plt.show()