In [None]:
import time
import psutil
import torch
from datasets import load_dataset
from transformers import BertTokenizer, BertForSequenceClassification, Trainer, TrainingArguments, DataCollatorWithPadding
import pandas as pd
from ace_tools import display_dataframe_to_user

In [None]:
# Single checkpoint id used by both loaders below, so the tokenizer and the
# classifier are always pulled from the same source.
model_name = "prajjwal1/bert-tiny"

# Sequence classifier configured with two output labels.
model = BertForSequenceClassification.from_pretrained(model_name, num_labels=2)
tokenizer = BertTokenizer.from_pretrained(model_name)

In [None]:
# Split name -> CSV file (paths are relative to the notebook's working directory).
csv_splits = {
    'train': 'sst2_train.csv',
    'validation': 'sst2_validation.csv',
}
dataset = load_dataset('csv', data_files=csv_splits)

In [None]:
def tokenize_function(example):
    """Tokenize the ``'sentence'`` field of ``example`` with the notebook-level tokenizer.

    Truncation is enabled; no padding is requested here.
    The tokenizer's return value is passed back unchanged.
    """
    sentences = example['sentence']
    return tokenizer(sentences, truncation=True)
# Collator built from the same tokenizer that produces the encodings.
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
# batched=True: tokenize_function is applied to batches of rows rather than one row at a time.
tokenized_datasets = dataset.map(tokenize_function, batched=True)

In [None]:
subset_sizes = [25, 50, 100]
results = []


def compute_metrics(eval_pred):
    """Compute classification accuracy for the Trainer's evaluation loop.

    Args:
        eval_pred: object exposing ``predictions`` (per-class scores, one row
            per example) and ``label_ids`` (reference labels), as handed to
            ``compute_metrics`` by ``Trainer``.

    Returns:
        ``{'accuracy': float}``. ``Trainer`` adds the ``eval_`` prefix, so the
        value is reported as ``eval_accuracy``.
    """
    predicted = eval_pred.predictions.argmax(axis=-1)
    return {'accuracy': float((predicted == eval_pred.label_ids).mean())}


# Shuffle once: the permutation is the same for every size (fixed seed), so
# each subset is a prefix of the same ordering.
shuffled_train = tokenized_datasets['train'].shuffle(seed=42)
process = psutil.Process()

for size in subset_sizes:
    subset = shuffled_train.select(range(size))

    # Start every run from a freshly loaded checkpoint. Reusing one model
    # object across iterations would let larger subsets continue from weights
    # already fine-tuned on the smaller ones, so the rows would not be
    # independent baselines. Seeding first keeps the newly initialised
    # classifier head identical across runs.
    torch.manual_seed(42)
    subset_model = BertForSequenceClassification.from_pretrained(model_name, num_labels=2)

    training_args = TrainingArguments(
        output_dir=f'./results_subset_{size}',
        num_train_epochs=3,
        per_device_train_batch_size=8,
        per_device_eval_batch_size=8,
        logging_steps=10,
        evaluation_strategy="epoch",
        save_strategy="no",
        disable_tqdm=True
    )

    trainer = Trainer(
        model=subset_model,
        args=training_args,
        train_dataset=subset,
        eval_dataset=tokenized_datasets['validation'],
        tokenizer=tokenizer,
        data_collator=data_collator,
        # Without this, evaluate() reports only loss/runtime and no accuracy.
        compute_metrics=compute_metrics
    )

    mem_before = process.memory_info().rss / (1024 ** 2)
    start_time = time.time()

    trainer.train()

    # Stop the clock before the final evaluation so it is not counted as
    # training time. NOTE: the per-epoch evaluations triggered by
    # evaluation_strategy="epoch" still run inside train() and are included.
    time_elapsed = time.time() - start_time

    eval_output = trainer.evaluate()
    # RSS is sampled after evaluation; this is a point-in-time reading, not a peak.
    mem_after = process.memory_info().rss / (1024 ** 2)

    results.append({
        'subset_size': size,
        'train_time_sec': round(time_elapsed, 2),
        'memory_before_mb': round(mem_before, 2),
        'memory_after_mb': round(mem_after, 2),
        # Index directly: a missing metric should fail loudly rather than be
        # recorded as a misleading 0.
        'eval_accuracy': round(eval_output['eval_accuracy'], 4)
    })


df_results = pd.DataFrame(results)
display_dataframe_to_user("Subset Selection Baseline Results", df_results)
