In [None]:
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from torch.optim import AdamW
import torch
import pandas as pd
from sklearn.metrics import classification_report
from sklearn.preprocessing import LabelEncoder
from torch.amp import GradScaler, autocast
from tqdm import tqdm
import numpy as np
import random
import transformers, sklearn, platform, sys
import time

# ---- Reproducibility --------------------------------------------------------
seed = 677
# Seed the Python, NumPy and PyTorch (CPU) generators, in that order.
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    _seed_fn(seed)
if torch.cuda.is_available():
    # On GPU machines also seed every CUDA device and force deterministic cuDNN
    # kernels (disabling the autotuner, which may pick different kernels per run).
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

# ---- Experiment configuration -----------------------------------------------
tokenizer_model = "roberta-base"  # checkpoint used for both the tokenizer and the classifier
code = "base_model"               # run tag, used in the run summary and report filenames
training_group = "whole"          # training subset: "males", "females" or "whole"
epochs = 2                        # number of passes over the training subset
use_amp = True                    # request mixed precision (autocast + GradScaler)

# Read the three pre-made splits.
training_df = pd.read_csv("BIOS_train.csv")
validation_df = pd.read_csv("BIOS_val.csv")
test_df = pd.read_csv("BIOS_test.csv")

# Separate the target column ('title') from every other column of each split.
X_train, y_train = training_df.drop(columns='title'), training_df['title']
X_valid, y_valid = validation_df.drop(columns='title'), validation_df['title']
X_test, y_test = test_df.drop(columns='title'), test_df['title']

# The class vocabulary is learned from the training titles only; print it so the
# integer ids in later reports can be traced back to occupation names.
label_encoder = LabelEncoder().fit(y_train)
print("\nLabel Encoding Map (class index --> occupation title):")
for class_idx, class_name in enumerate(label_encoder.classes_):
    print(f"{class_idx}: {class_name}")

# Swap the string titles of every split for their integer class ids.
y_train, y_valid, y_test = [
    label_encoder.transform(split_titles) for split_titles in (y_train, y_valid, y_test)
]


def _split_by_gender(X, y, gender):
    """Return a copy of the rows of *X* whose 'gender' equals *gender*, and the
    entries of *y* selected by the same boolean mask."""
    keep = X['gender'] == gender
    return X[keep].copy(), y[keep]


X_train_males, y_train_males = _split_by_gender(X_train, y_train, 'M')
X_train_females, y_train_females = _split_by_gender(X_train, y_train, 'F')

X_valid_males, y_valid_males = _split_by_gender(X_valid, y_valid, 'M')
X_valid_females, y_valid_females = _split_by_gender(X_valid, y_valid, 'F')

X_test_males, y_test_males = _split_by_gender(X_test, y_test, 'M')
X_test_females, y_test_females = _split_by_gender(X_test, y_test, 'F')

# (display name, features, labels, expected gender) for every subset built above.
_gender_subsets = [
    ("X_train_males", X_train_males, y_train_males, 'M'),
    ("X_train_females", X_train_females, y_train_females, 'F'),
    ("X_valid_males", X_valid_males, y_valid_males, 'M'),
    ("X_valid_females", X_valid_females, y_valid_females, 'F'),
    ("X_test_males", X_test_males, y_test_males, 'M'),
    ("X_test_females", X_test_females, y_test_females, 'F'),
]

# First pass: each subset must be non-empty and contain exactly one gender value,
# namely the expected one (the nunique() check fails first on an empty subset).
for _name, _X, _y, _gender in _gender_subsets:
    assert _X['gender'].nunique() == 1 and _X['gender'].iloc[0] == _gender, f"{_name} has unexpected gender values"
print("Gender splits: all subsets contain only the expected gender")

# Second pass: features and labels must line up, and report each subset's size.
for _name, _X, _y, _gender in _gender_subsets:
    assert len(_X) == len(_y), f"Mismatch in {_name} and corresponding labels"
    print(f"{_name}: {len(_X)} samples")

print("All splits are valid and properly structured.")



# Map each allowed training_group to its (train X, train y, valid X, valid y) data.
_group_to_splits = {
    "females": (X_train_females, y_train_females, X_valid_females, y_valid_females),
    "males": (X_train_males, y_train_males, X_valid_males, y_valid_males),
    "whole": (X_train, y_train, X_valid, y_valid),
}
if training_group not in _group_to_splits:
    raise ValueError(f"Invalid training_group '{training_group}'.")

_X_tr, y_train_group, _X_va, y_valid_group = _group_to_splits[training_group]
# Only the 'bio' text column is passed on to tokenization.
X_train_group, X_valid_group = _X_tr['bio'], _X_va['bio']

print(f"[INFO] Training group: {training_group} — Number of training examples: {len(X_train_group)}")

# Tokenizer (loaded from the same checkpoint name as the classifier below)
tokenizer = AutoTokenizer.from_pretrained(tokenizer_model)

def tokenize_function(texts, max_length=160):
    """Tokenize a batch of texts into fixed-length PyTorch tensors.

    Args:
        texts: Iterable of strings (e.g. a pandas Series of bios). A single
            string is treated as a one-element batch instead of being split
            into characters by ``list()``.
        max_length: Length every sequence is padded / truncated to. Defaults
            to 160, the value this function always used.

    Returns:
        The tokenizer output built with ``return_tensors="pt"``; callers in
        this script read its ``'input_ids'`` and ``'attention_mask'`` entries.
    """
    batch = [texts] if isinstance(texts, str) else list(texts)
    return tokenizer(batch, padding="max_length", truncation=True, max_length=max_length, return_tensors="pt")

class BIOS_dataset(torch.utils.data.Dataset):
    """Map-style dataset pairing pre-tokenized encodings with integer labels.

    Args:
        encodings: Mapping with ``'input_ids'`` and ``'attention_mask'``
            entries, each indexable by example position (in this script, the
            tensors returned by ``tokenize_function``).
        labels: 1-D sequence of integer class ids (NumPy array, list or
            pandas Series). It is copied once into a ``torch.long`` tensor, so
            lookups are always positional, even for a Series whose index is
            not ``0..n-1`` (plain ``series[idx]`` would be a label lookup).

    Raises:
        ValueError: if the number of labels differs from the number of
            encoded examples.
    """

    def __init__(self, encodings, labels):
        self.encodings = encodings
        # np.asarray drops any pandas index; torch.tensor copies the data.
        self.labels = torch.tensor(np.asarray(labels), dtype=torch.long)
        n_encoded = len(encodings['input_ids'])
        if n_encoded != len(self.labels):
            raise ValueError(
                f"encodings contain {n_encoded} examples but {len(self.labels)} labels were given"
            )

    def __getitem__(self, idx):
        return {
            'input_ids': self.encodings['input_ids'][idx],
            'attention_mask': self.encodings['attention_mask'][idx],
            'labels': self.labels[idx],  # 0-d long tensor, as before
        }

    def __len__(self):
        return len(self.labels)


# Tokenize each split once, up front; items are then plain tensor lookups.
train_dataset = BIOS_dataset(encodings=tokenize_function(X_train_group), labels=y_train_group)
valid_dataset = BIOS_dataset(encodings=tokenize_function(X_valid_group), labels=y_valid_group)

# Only the training loader shuffles; validation batches keep the original row order.
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=32, shuffle=True)
valid_loader = torch.utils.data.DataLoader(dataset=valid_dataset, batch_size=32)

# Model: a sequence classifier with one output per encoded occupation class.
num_labels = len(label_encoder.classes_)
model = AutoModelForSequenceClassification.from_pretrained(tokenizer_model, num_labels=num_labels)

# Per-checkpoint learning rate. Unknown checkpoints fail here with a clear message
# instead of leaving `lr` undefined and raising NameError at the optimizer.
if tokenizer_model == "bert-base-uncased":
    lr = 2e-5
elif tokenizer_model == "roberta-base":
    lr = 2e-5
elif tokenizer_model == "distilroberta-base":
    lr = 5e-5
else:
    raise ValueError(f"No learning rate configured for tokenizer_model '{tokenizer_model}'.")

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Move the model before building the optimizer, per the torch.optim guidance.
model.to(device)
optimizer = AdamW(model.parameters(), lr=lr)
# Loss scaling is only used together with CUDA autocast. Enable it explicitly on
# CUDA rather than relying on GradScaler's warn-and-disable fallback elsewhere.
scaler = GradScaler("cuda", enabled=use_amp and device.type == "cuda")
print(f"\nUsing device: {device}")


# Mixed precision only applies on CUDA. On CPU, take the plain fp32 path instead
# of entering a CUDA autocast context that would only warn and disable itself.
amp_enabled = use_amp and device.type == "cuda"

start_time = time.time()
for epoch in range(epochs):
    model.train()
    total_loss = 0
    for batch in tqdm(train_loader, desc=f"Training Epoch {epoch + 1}"):
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)

        optimizer.zero_grad()
        if amp_enabled:
            with autocast(device_type='cuda'):
                outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
                loss = outputs.loss
            scaler.scale(loss).backward()
            # Unscale before clipping so the max_norm applies to the true gradients.
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            scaler.step(optimizer)
            scaler.update()
        else:
            outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
            loss = outputs.loss
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()

        total_loss += loss.item()

    avg_loss = total_loss / len(train_loader)
    print(f"Epoch {epoch + 1} - Avg Training Loss: {avg_loss:.4f}")

    # Validation
    model.eval()
    predictions, true_labels = [], []
    with torch.no_grad():
        for batch in tqdm(valid_loader, desc=f"Validation Epoch {epoch + 1}"):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            # Labels are only needed on the host for the report, so skip the device round-trip.
            labels = batch['labels']

            outputs = model(input_ids, attention_mask=attention_mask)
            preds = torch.argmax(outputs.logits, dim=-1)

            predictions.extend(preds.cpu().numpy())
            true_labels.extend(labels.numpy())

    decoded_true = label_encoder.inverse_transform(true_labels)
    decoded_preds = label_encoder.inverse_transform(predictions)

    print(f"\nValidation Report (Epoch {epoch + 1}):")
    # Pass `labels` explicitly. Otherwise sklearn infers the label set from the classes
    # present in this split and raises ValueError when that count != len(target_names).
    print(classification_report(decoded_true, decoded_preds,
                                labels=label_encoder.classes_,
                                target_names=label_encoder.classes_))
end_time = time.time()
print(f"Training took {end_time - start_time:.2f} seconds")

# Testing on the 3 splits: evaluate the fine-tuned model on the whole test set
# and on each gender subset, printing a report and saving it to Excel.
test_groups = {
    "whole": (X_test['bio'], y_test),
    "males": (X_test_males['bio'], y_test_males),
    "females": (X_test_females['bio'], y_test_females)
}

print(f"\n*** Used code: {code}. Training group: {training_group}. Model: {tokenizer_model}. Seed: {seed} ***")

# Filesystem-safe checkpoint name for the output filename (hub ids like "org/model" contain "/").
model_tag = tokenizer_model.replace("/", "_")

for group_name, (X_group, y_group) in test_groups.items():
    print(f"\nTesting on {group_name.upper()} group")
    test_dataset = BIOS_dataset(tokenize_function(X_group), y_group)
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=32)

    predictions, true_labels = [], []
    model.eval()
    with torch.no_grad():
        for batch in tqdm(test_loader, desc=f"Testing {group_name}"):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            # Labels stay on the host; they are only used for the report.
            labels = batch['labels']

            outputs = model(input_ids, attention_mask=attention_mask)
            preds = torch.argmax(outputs.logits, dim=-1)

            predictions.extend(preds.cpu().numpy())
            true_labels.extend(labels.numpy())

    decoded_true = label_encoder.inverse_transform(true_labels)
    decoded_preds = label_encoder.inverse_transform(predictions)

    # Fix the label set to all encoded classes: a gender subset may lack some
    # occupations, and without `labels` sklearn would then raise on target_names.
    report_kwargs = dict(labels=label_encoder.classes_, target_names=label_encoder.classes_)

    print("\nClassification Report:")
    print(classification_report(decoded_true, decoded_preds, **report_kwargs))

    # SAVE TO EXCEL
    report_dict = classification_report(decoded_true, decoded_preds, output_dict=True, **report_kwargs)

    report_df = pd.DataFrame(report_dict).transpose()

    if "accuracy" in report_df.index:
        # The scalar accuracy is broadcast into every column when the DataFrame is built.
        # Keep it under f1-score only, and use the number of evaluated examples as support.
        report_df.loc["accuracy", ["precision", "recall"]] = [float("nan"), float("nan")]
        report_df.loc["accuracy", "support"] = len(decoded_true)

    # Convert the whole column. Assigning Int64 values into a subset of a float64
    # column leaves it float64, so supports were written as e.g. "12.0".
    report_df["support"] = (
        pd.to_numeric(report_df["support"], errors="coerce").round().astype("Int64")
    )

    for col in ["precision", "recall", "f1-score"]:
        if col in report_df.columns:
            report_df[col] = pd.to_numeric(report_df[col], errors="coerce").round(2)

    report_df = report_df.astype(str)
    report_df.to_excel(f"classification_report_{code}_{model_tag}_{seed}_{training_group}_{group_name}.xlsx")
    print("#####################################################")
