In [10]:
import os
import pickle
from typing import Literal

import numpy as np
from datasets import load_dataset, load_metric
from pydantic.dataclasses import dataclass
from torch.utils.data import DataLoader
from transformers import AutoTokenizer, AutoModelForSequenceClassification, TrainingArguments, Trainer, set_seed

import wandb
# Authenticate with Weights & Biases; the Trainer runs below use report_to="wandb".
wandb.login()
# Send every run from this notebook to one W&B project (the W&B integration reads WANDB_PROJECT).
%env WANDB_PROJECT=subset_active_learning_corrected

env: WANDB_PROJECT=subset_active_learning_corrected


In [11]:
# Cumulative training-set sizes, one per active-learning step (each step samples only the increment).
# NOTE(review): differs from Config.sampling_sizes — confirm which schedule the learner should use.
sampling_sizes = 1000, 2000, 5000

In [12]:
class ActiveLearner:
    """Active-learning loop for binary sentiment classification on SST.

    Each step draws new indices from the SST train split using ``config.strategy``.
    It then fine-tunes a fresh model on every index selected so far and pickles that
    model's predictions on the SST test split.

    Args:
        config: Experiment configuration. Fields read: ``model_name``, ``max_length``,
            ``strategy``, ``sampling_sizes``, ``max_steps`` and ``debug``.
    """

    def __init__(self, config):
        self.config = config

        set_seed(42)
        self.tokenizer = AutoTokenizer.from_pretrained(config.model_name)
        self.metric = load_metric("accuracy")

        self.sst2 = load_dataset("sst")
        # Validation/test splits never change between steps, so tokenize them once here.
        self.valid_ds = self.preprocess(self.sst2["validation"])
        self.test_ds = self.preprocess(self.sst2["test"])
        # Cumulative list of train-split indices selected ("labelled") so far.
        self.train_data_indices = []

    def preprocess(self, data):
        """Binarize the sentiment score and tokenize a split into fixed-length torch tensors.

        The original ``label`` column is renamed to ``scalar_label``. It is then
        thresholded at 0.5 (``< 0.5`` -> 0, otherwise 1). Only the tokenizer outputs
        plus ``labels`` survive.

        Args:
            data: A ``datasets`` split with ``sentence`` and ``label`` columns.

        Returns:
            The tokenized dataset, formatted to yield torch tensors.
        """
        data = data.rename_column("label", "scalar_label")
        data = data.map(lambda x: {"label": 0 if x["scalar_label"] < 0.5 else 1})

        def tokenize_func(examples):
            tokenized = self.tokenizer(
                examples["sentence"], padding="max_length", max_length=self.config.max_length, truncation=True
            )
            tokenized["labels"] = examples["label"]
            return tokenized

        ds = data.map(
            tokenize_func,
            remove_columns=data.column_names,
            batched=True,
        )
        ds.set_format(type="torch")
        return ds

    def compute_metrics(self, eval_pred):
        """Return the accuracy of argmax predictions; ``eval_pred`` is ``(logits, labels)``."""
        logits, labels = eval_pred
        predictions = np.argmax(logits, axis=-1)
        return self.metric.compute(predictions=predictions, references=labels)

    def sample_data(self, strategy, n_new_samples):
        """Select ``n_new_samples`` train-split indices that have not been selected yet.

        Args:
            strategy: Sampling strategy name. Only ``"random_sampling"`` is implemented.
            n_new_samples: Number of new indices to draw.

        Returns:
            list[int]: Newly drawn indices, disjoint from ``self.train_data_indices``.

        Raises:
            NotImplementedError: If ``strategy`` has no implementation.
            ValueError: If fewer than ``n_new_samples`` unselected indices remain.
        """
        if strategy != "random_sampling":
            raise NotImplementedError(f"Sampling strategy {strategy!r} is not implemented")
        print(f"sampling strategy is {strategy}")
        # Draw only from the not-yet-selected pool. Independent draws over the whole split
        # would re-select earlier examples and put duplicates into the training set.
        pool = np.setdiff1d(np.arange(len(self.sst2["train"])), self.train_data_indices)
        if n_new_samples > len(pool):
            raise ValueError(
                f"Requested {n_new_samples} new samples but only {len(pool)} unselected examples remain"
            )
        selected_indices = np.random.choice(pool, size=n_new_samples, replace=False)
        return selected_indices.tolist()

    def train(self):
        """Run one active-learning step per entry of ``config.sampling_sizes``.

        ``sampling_sizes`` holds cumulative training-set sizes, so each step samples only
        the difference from the previous size.
        """
        previous_size = 0
        for sampling_size in self.config.sampling_sizes:
            n_new_samples = sampling_size - previous_size
            print(f"Sampling {n_new_samples} new samples")
            self.step(n_new_samples)
            previous_size = sampling_size

    def step(self, n_new_samples):
        """Take an active learning step: grow the labelled set, fine-tune, predict on test.

        Args:
            n_new_samples: Number of new train-split examples to add before training.
        """
        ########### set up data #########
        new_indices = self.sample_data(self.config.strategy, n_new_samples)
        self.train_data_indices.extend(new_indices)
        train_data = self.sst2["train"].select(self.train_data_indices)
        # Tiny fixed subset used as both train and eval set in debug mode.
        debug_data = self.sst2["train"].select(self.train_data_indices[:8])

        self.train_ds = self.preprocess(train_data)
        self.debug_ds = self.preprocess(debug_data)

        ########### set up training #########
        n_train = len(self.train_data_indices)
        output_dir = "./debug" if self.config.debug else f"./{self.config.strategy}/size_{n_train}"
        training_args = TrainingArguments(
            output_dir=output_dir,
            max_steps=640 if self.config.debug else self.config.max_steps,
            evaluation_strategy="steps",
            report_to="wandb",
            run_name=f"{self.config.strategy}-size-{n_train}",
            eval_steps=300,
            learning_rate=1e-5,
            adam_epsilon=1e-6,
            warmup_ratio=0.1,
            weight_decay=0.01,
        )
        print(f"training_args: {training_args}")
        model = AutoModelForSequenceClassification.from_pretrained(self.config.model_name, num_labels=2)
        trainer = Trainer(
            model=model,
            args=training_args,
            train_dataset=self.debug_ds if self.config.debug else self.train_ds,
            eval_dataset=self.debug_ds if self.config.debug else self.valid_ds,
            compute_metrics=self.compute_metrics,
        )
        ######### train #######
        trainer.train()
        wandb.finish()

        ######## test ########
        outputs = trainer.predict(self.test_ds)
        # output_dir may not exist yet if the Trainer saved no checkpoint, so create it.
        os.makedirs(output_dir, exist_ok=True)
        with open(os.path.join(output_dir, f"test_set_evaluation_{n_train}.pkl"), "wb") as f:
            pickle.dump(outputs, f)


In [16]:
@dataclass(frozen=True)
class Config:
    """Settings for one active-learning experiment.

    Fields:
        max_length: Token length every sentence is padded/truncated to.
        debug: If True, train and evaluate on a tiny subset with a short fixed step budget.
        model_name: Hugging Face model identifier used for both tokenizer and classifier.
        strategy: How new training examples are chosen at each step.
        sampling_sizes: Cumulative training-set sizes, one per active-learning step.
            Must be non-empty, positive and strictly increasing.
        max_steps: Optimizer steps per training run when not in debug mode.
    """

    max_length: int = 66
    debug: bool = False
    model_name: str = "google/electra-small-discriminator"
    strategy: Literal["random_sampling", "uncertainty_sampling"] = "random_sampling"
    sampling_sizes: tuple = (1000, 2000, 3000, 4000)
    max_steps: int = 20000

    def __post_init__(self):
        # Each step samples sizes[i] - sizes[i-1] new examples. A non-increasing schedule
        # would request zero or negative sample counts, so reject it up front.
        sizes = list(self.sampling_sizes)
        if not sizes:
            raise ValueError("sampling_sizes must contain at least one size")
        if sizes[0] <= 0:
            raise ValueError(f"sampling_sizes must be positive, got {self.sampling_sizes!r}")
        if any(later <= earlier for earlier, later in zip(sizes, sizes[1:])):
            raise ValueError(f"sampling_sizes must be strictly increasing, got {self.sampling_sizes!r}")

# NOTE: ActiveLearner.sample_data only implements "random_sampling".
# With "uncertainty_sampling" the very first active-learning step fails.
config = Config(debug=True, sampling_sizes=(1000, 2000), strategy="random_sampling")

In [9]:
# Sanity check: display the sampling strategy the next run will use.
getattr(config, "strategy")

'ss'

In [None]:
# Build the learner (loads SST, tokenizer and accuracy metric) and fine-tune once per sampling size.
active_learner = ActiveLearner(config=config)
active_learner.train()

In [25]:
# Reload the pickled test-set predictions written after the first active-learning step.
# The directory mirrors ActiveLearner.step: "./debug" in debug mode, otherwise "./<strategy>/size_<n>".
# pickle.load can execute arbitrary code — only load files this notebook produced itself.
first_size = config.sampling_sizes[0]
results_dir = "./debug" if config.debug else f"./{config.strategy}/size_{first_size}"
with open(os.path.join(results_dir, f"test_set_evaluation_{first_size}.pkl"), "rb") as f:
    loaded_outputs = pickle.load(f)

In [14]:
# Close any W&B run still active in this kernel (presumably a no-op when none is active).
wandb.finish()