# In [None]:
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from torch.optim import AdamW
from torch.nn.functional import normalize
from torch.amp import GradScaler, autocast
from sklearn.metrics import classification_report
from tqdm import tqdm, trange
import torch
import pandas as pd
import numpy as np
import random
import platform
import sys
import sklearn
import transformers
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt

# ---------------------------------------------------------------------------
# Experiment configuration
# ---------------------------------------------------------------------------
tokenizer_model = "roberta-base"  # HF checkpoint used for the tokenizer and both models
training_group = "whole"  # females, males, whole
coreset_fraction = 0.25  # share of the training group kept in the coreset
epochs = 2
batch_size = 32
code = "KCG"  # run tag echoed in the test-set summary

# ---------------------------------------------------------------------------
# Reproducibility: seed Python, NumPy and PyTorch (plus every CUDA device).
# ---------------------------------------------------------------------------
seed = 677
for _seed_rng in (random.seed, np.random.seed, torch.manual_seed):
    _seed_rng(seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(seed)
    # Prefer reproducible cuDNN kernels over autotuned, potentially nondeterministic ones.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True


# Load the three pre-made PAN16 splits (read in training, validation, test order).
training_df, validation_df, test_df = (
    pd.read_csv(csv_path)
    for csv_path in (
        "PAN16_training_df.csv",
        "PAN16_validation_df.csv",
        "PAN16_test_df.csv",
    )
)


def _split_features_labels(frame):
    """Return (every column except 'task_label', the 'task_label' column) of *frame*."""
    return frame.drop(columns='task_label'), frame['task_label']


X_train, y_train = _split_features_labels(training_df)
X_valid, y_valid = _split_features_labels(validation_df)
X_test, y_test = _split_features_labels(test_df)

# Features and labels come from the same frame, so their lengths must agree.
for _features, _labels in ((X_train, y_train), (X_valid, y_valid), (X_test, y_test)):
    assert len(_features) == len(_labels)
print("Split integrity verified")


def _gender_subset(features, labels, gender):
    """Return a copy of the rows of *features* whose 'gender' equals *gender*, plus the matching *labels*."""
    is_gender = features['gender'] == gender
    return features[is_gender].copy(), labels[is_gender]


X_train_males, y_train_males = _gender_subset(X_train, y_train, 'male')
X_train_females, y_train_females = _gender_subset(X_train, y_train, 'female')

X_valid_males, y_valid_males = _gender_subset(X_valid, y_valid, 'male')
X_valid_females, y_valid_females = _gender_subset(X_valid, y_valid, 'female')

X_test_males, y_test_males = _gender_subset(X_test, y_test, 'male')
X_test_females, y_test_females = _gender_subset(X_test, y_test, 'female')

# Every subset must hold exactly one distinct gender value.
assert all(
    subset['gender'].nunique() == 1
    for subset in (
        X_train_males, X_train_females,
        X_valid_males, X_valid_females,
        X_test_males, X_test_females,
    )
)
print("Gender splits confirmed")


# Pick the demographic subset (text column + labels) used for training and validation.
if training_group == "females":
    X_train_group, y_train_group = X_train_females['text'], y_train_females
    X_valid_group, y_valid_group = X_valid_females['text'], y_valid_females
elif training_group == "males":
    X_train_group, y_train_group = X_train_males['text'], y_train_males
    X_valid_group, y_valid_group = X_valid_males['text'], y_valid_males
elif training_group == "whole":
    X_train_group, y_train_group = X_train['text'], y_train
    X_valid_group, y_valid_group = X_valid['text'], y_valid
else:
    # Without this check a typo (e.g. "female") would silently train on the whole dataset.
    raise ValueError(
        f"Unknown training_group {training_group!r}; expected 'females', 'males' or 'whole'"
    )

print(f"[INFO] Training group: {training_group} — Number of available training examples: {len(X_train_group)}")


tokenizer = AutoTokenizer.from_pretrained(tokenizer_model)


def tokenize_function(texts):
    """Tokenize *texts* into PyTorch tensors of exactly 64 tokens per sequence.

    Longer sequences are truncated and shorter ones padded up to 64 tokens, so
    all sequences in the returned encoding have the same length.
    """
    text_list = list(texts)
    encoded = tokenizer(
        text_list,
        padding="max_length",
        truncation=True,
        max_length=64,
        return_tensors="pt",
    )
    return encoded

class PAN16DATASET(torch.utils.data.Dataset):
    """Map-style dataset pairing pre-tokenized encodings with integer labels.

    Args:
        encodings: mapping whose 'input_ids' and 'attention_mask' entries are
            indexable by position (e.g. a tokenizer output with return_tensors="pt").
        labels: pandas Series of class ids, read positionally via ``.iloc``.
    """

    _ENCODING_KEYS = ('input_ids', 'attention_mask')

    def __init__(self, encodings, labels):
        self.encodings, self.labels = encodings, labels

    def __len__(self):
        # The dataset size is defined by the label series.
        return len(self.labels)

    def __getitem__(self, idx):
        item = {key: self.encodings[key][idx] for key in self._ENCODING_KEYS}
        item['labels'] = torch.tensor(self.labels.iloc[idx], dtype=torch.long)
        return item


@torch.no_grad()
def compute_embeddings(texts, model, tokenizer, device):
    """Embed *texts* as L2-normalized first-token vectors of the last hidden layer.

    Texts are processed in chunks of the module-level ``batch_size``. Each chunk
    is padded or truncated to 64 tokens and run through *model* with hidden
    states requested. The vector at position 0 of the last hidden layer is kept.

    Args:
        texts: pandas Series of strings (sliced positionally with ``.iloc``).
        model: transformers model accepting ``output_hidden_states=True``.
        tokenizer: tokenizer matching *model*.
        device: ``torch.device`` (or device string) used for inference.

    Returns:
        A float32 tensor on *device* with one unit-norm row per text, in the
        same order as ``texts``.
    """
    device = torch.device(device)
    model.eval()
    # Mixed precision only on CUDA. On other devices, run in full precision
    # instead of requesting a CUDA autocast that torch would disable with a warning.
    use_amp = device.type == "cuda"
    chunks = []
    for start in tqdm(range(0, len(texts), batch_size), desc="Computing Embeddings"):
        batch = texts.iloc[start:start + batch_size]
        encodings = tokenizer(list(batch), padding="max_length", truncation=True, max_length=64, return_tensors="pt")
        input_ids = encodings['input_ids'].to(device)
        attention_mask = encodings['attention_mask'].to(device)

        with autocast(device_type=device.type, enabled=use_amp):
            outputs = model(input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True, return_dict=True)
        # Guarantee float32 regardless of the dtype autocast produced.
        first_token = outputs.hidden_states[-1][:, 0, :].float()
        chunks.append(normalize(first_token, dim=1).cpu())
    return torch.cat(chunks, dim=0).to(device)


def k_center_greedy(embeddings, k):
    """Select up to *k* distinct row indices of *embeddings* via greedy K-Center.

    Starts from a uniformly random row (drawn from the torch RNG). It then
    repeatedly adds the row whose Euclidean distance to its nearest
    already-selected row is largest.

    Args:
        embeddings: 2-D tensor, one row per candidate point.
        k: number of points to select. It is clamped to ``[0, len(embeddings)]``.

    Returns:
        A list of distinct positional indices in selection order. It is empty
        when ``k <= 0``.
    """
    n = len(embeddings)
    k = min(int(k), n)
    if k <= 0:
        return []

    first = torch.randint(n, (1,), device=embeddings.device).item()
    selected = [first]
    # distances[i] = distance from row i to its nearest selected centre.
    distances = torch.cdist(embeddings, embeddings[[first]]).squeeze(1)
    # Mask chosen rows explicitly. Their self-distance is not guaranteed to be
    # exactly 0, and duplicate rows tie at 0, so argmax could re-pick an
    # already-selected index.
    distances[first] = float("-inf")

    for _ in trange(1, k, desc="K-Greedy Selection"):
        idx = torch.argmax(distances).item()
        selected.append(idx)
        new_dist = torch.cdist(embeddings, embeddings[[idx]]).squeeze(1)
        distances = torch.minimum(distances, new_dist)
        distances[idx] = float("-inf")
    return selected


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

if not 0 < coreset_fraction <= 1:
    raise ValueError(f"coreset_fraction must be in (0, 1], got {coreset_fraction}")

# Frozen pre-trained model used only to embed the training pool for coreset selection.
embedding_model = AutoModelForSequenceClassification.from_pretrained(tokenizer_model, num_labels=2).to(device)
embedding_model.requires_grad_(False)

embeddings = compute_embeddings(X_train_group, embedding_model, tokenizer, device)
# int() truncates, so keep at least one example: the coreset must never be empty,
# and the printed count should match what is actually selected.
k = max(1, int(coreset_fraction * len(X_train_group)))
print(f"INFO: Number of selected training examples: {k}")
selected_indices = k_center_greedy(embeddings, k)
# selected_indices are positions into X_train_group (same order as the embedding rows).
X_train_core = X_train_group.iloc[selected_indices]
y_train_core = y_train_group.iloc[selected_indices]

# Visualization: 2-D t-SNE projection of the embedding pool with the coreset highlighted.
# .float() is a no-op for float32 input; it ensures NumPy/scikit-learn get a standard float dtype.
embeddings_np = embeddings.float().cpu().numpy()
# An explicit integer dtype keeps mask indexing valid even for an empty selection.
selected_positions = np.asarray(selected_indices, dtype=int)
assert selected_positions.size == 0 or selected_positions.max() < len(embeddings), "Index out of bounds in selection"
selected_mask = np.zeros(len(embeddings), dtype=bool)
selected_mask[selected_positions] = True
print(f"Selected points: {selected_mask.sum()} / {len(selected_mask)} total")

# scikit-learn requires perplexity < n_samples; only tiny pools get a smaller value.
tsne = TSNE(n_components=2, random_state=seed, perplexity=min(30, len(embeddings_np) - 1))
embeddings_2d = tsne.fit_transform(embeddings_np)
plt.figure(figsize=(14, 12))

plt.scatter(
    embeddings_2d[~selected_mask, 0],
    embeddings_2d[~selected_mask, 1],
    c='gray', alpha=0.5, label='Unselected',
    s=10, edgecolors='none'
)
plt.scatter(
    embeddings_2d[selected_mask, 0],
    embeddings_2d[selected_mask, 1],
    c='red', alpha=0.7, label='Selected',
    s=10, edgecolors='none'
)
plt.title("t-SNE Visualization of Embedding Space (K-Center Greedy)")
plt.xlabel("t-SNE Component 1")
plt.ylabel("t-SNE Component 2")
plt.legend()
plt.grid(True)
plt.tight_layout()
# Raster alternative: plt.savefig(f"DS_{training_group}_{tokenizer_model}_{seed}.png", dpi=300)
plt.savefig(f"DS_{training_group}_{tokenizer_model}_{seed}.pdf", bbox_inches='tight')


# Persist the DataFrame index labels of the chosen rows (sorted) for later reuse.
coreset_row_labels = sorted(X_train_core.index)
index_df = pd.DataFrame({"selected_indices": coreset_row_labels})
index_file = f"DS_indices_{training_group}_{tokenizer_model}_{seed}.csv"
index_df.to_csv(index_file, index=False)
print(f"Saved coreset indices to: {index_file}")


# Fine-tuning learning rate for each supported model. Resolved first so an
# unsupported model fails fast, before any tokenization or download, instead of
# raising a NameError at the optimizer.
learning_rates = {
    "bert-base-uncased": 2e-5,
    "roberta-base": 2e-5,
    "distilroberta-base": 5e-5,
}
if tokenizer_model not in learning_rates:
    raise ValueError(
        f"No learning rate configured for {tokenizer_model!r}; "
        f"expected one of {sorted(learning_rates)}"
    )
lr = learning_rates[tokenizer_model]

train_dataset = PAN16DATASET(tokenize_function(X_train_core), y_train_core)
valid_dataset = PAN16DATASET(tokenize_function(X_valid_group), y_valid_group)

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
valid_loader = torch.utils.data.DataLoader(valid_dataset, batch_size=batch_size)

model = AutoModelForSequenceClassification.from_pretrained(tokenizer_model, num_labels=2).to(device)
optimizer = AdamW(model.parameters(), lr=lr)
# Loss scaling only matters for CUDA mixed precision; make it a pass-through elsewhere.
scaler = GradScaler(enabled=device.type == "cuda")

# Mixed precision is enabled only on CUDA; on CPU the loop runs in float32.
use_amp = device.type == "cuda"

for epoch in range(epochs):
    model.train()
    total_loss = 0
    for batch in tqdm(train_loader, desc=f"Training Epoch {epoch+1}"):
        input_ids = batch['input_ids'].to(device)
        mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)

        optimizer.zero_grad()
        with autocast(device_type=device.type, enabled=use_amp):
            outputs = model(input_ids, attention_mask=mask, labels=labels)
            loss = outputs.loss
        scaler.scale(loss).backward()
        # Unscale before clipping so the 1.0 max-norm applies to the true gradients.
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        scaler.step(optimizer)
        scaler.update()
        total_loss += loss.item()
    print(f"Epoch {epoch+1} - Loss: {total_loss / len(train_loader):.4f}")

    # Per-epoch validation on the selected group's validation split.
    model.eval()
    preds, true_labels = [], []
    with torch.no_grad():
        for batch in tqdm(valid_loader, desc="Validation"):
            input_ids = batch['input_ids'].to(device)
            mask = batch['attention_mask'].to(device)
            logits = model(input_ids, attention_mask=mask).logits
            preds.extend(torch.argmax(logits, dim=1).cpu().numpy())
            true_labels.extend(batch['labels'].numpy())
    # labels=[0, 1] keeps both target_names aligned even if only one class is
    # present in the true labels and the predictions.
    print(classification_report(true_labels, preds, labels=[0, 1], target_names=["No Mention (0)", "Mention (1)"]))


# Evaluate the fine-tuned model on the whole test split and on each gender subset.
test_sets = {
    "whole": (X_test['text'], y_test),
    "males": (X_test_males['text'], y_test_males),
    "females": (X_test_females['text'], y_test_females),
}

print(f"\n*** Used code: {code}. Training group: {training_group}. Model: {tokenizer_model}. Seed: {seed} ***")
model.eval()
for name, (Xg, yg) in test_sets.items():
    print(f"\n--- Testing on {name.upper()} ---")
    test_data = PAN16DATASET(tokenize_function(Xg), yg)
    # Use the configured batch size rather than a separately hard-coded 32.
    test_loader = torch.utils.data.DataLoader(test_data, batch_size=batch_size)
    preds, labels_all = [], []
    with torch.no_grad():
        for batch in tqdm(test_loader, desc=f"Testing {name}"):
            input_ids = batch['input_ids'].to(device)
            mask = batch['attention_mask'].to(device)
            logits = model(input_ids, attention_mask=mask).logits
            preds.extend(torch.argmax(logits, dim=1).cpu().numpy())
            labels_all.extend(batch['labels'].numpy())
    # labels=[0, 1] keeps both target_names aligned even if a subset contains,
    # or is predicted as, only one class.
    print(classification_report(labels_all, preds, labels=[0, 1], target_names=["No Mention (0)", "Mention (1)"]))
    print("#####################################################")
