In [None]:
!pip install -r requirements.txt

In [1]:
import os
import datetime

import datasets
import numpy
import rubrix as rb

from pathlib import Path
from rubrix.listeners import listener

from active_learning_test.active_learner import (
    build_active_learner,
    convert_to_small_text_dataset,
    initialize_active_learner
)

In [2]:
def initialize_rubrix(initial_indices, trec_dataset, label_names):
    """Log the initially labelled samples to Rubrix as validated records.

    Args:
        initial_indices: Iterable of positions into the ``train`` split.
        trec_dataset: Mapping with a ``train`` split that exposes the
            ``text`` and ``label-coarse`` columns.
        label_names: Sequence that maps a ``label-coarse`` value to the
            label name used as the record annotation.

    Each record gets its position within ``initial_indices`` as id and is
    logged to the ``active-learning-test-batch-initial`` dataset.
    """
    train_split = trec_dataset['train']
    # Look the columns up once instead of repeating the column access for
    # every single index inside the comprehension.
    all_texts = train_split['text']
    all_labels = train_split['label-coarse']

    records = [
        rb.TextClassificationRecord(
            id=position,
            text=all_texts[sample_idx],
            annotation=label_names[all_labels[sample_idx]],
            status='Validated'
        )
        for position, sample_idx in enumerate(initial_indices)
    ]
    rb.log(records, name='active-learning-test-batch-initial')
    
def log_next_batch(batch_idx, trec_dataset, queried_indices):
    """Log the queried samples of one batch to Rubrix for annotation.

    Args:
        batch_idx: Number of the batch; stored in each record's metadata
            under ``batch_id`` and used as prefix of the record id.
        trec_dataset: Mapping with a ``train`` split that exposes the
            ``text`` column and the ``label-coarse`` feature names.
        queried_indices: Iterable of positions into the ``train`` split.

    Record ids have the form ``"<batch_idx>_<position>"`` where ``position``
    is the sample's position within ``queried_indices``. Every label is
    attached as a prediction with score 0.0.
    """
    train_split = trec_dataset['train']
    # Look up the text column and the label names once; the original repeated
    # both lookups for every queried index / record.
    all_texts = train_split['text']
    coarse_label_names = train_split.features["label-coarse"].names

    records = [
        rb.TextClassificationRecord(
            id=f"{batch_idx}_{position}",
            text=all_texts[sample_idx],
            # A fresh list per record so records never share a prediction list.
            prediction=[(label, 0.0) for label in coarse_label_names],
            metadata={"batch_id": batch_idx},
        )
        for position, sample_idx in enumerate(queried_indices)
    ]
    print(f"Logging records for batch {batch_idx}")
    rb.log(records, name="active-learning-test-batch")


In [3]:
# Load the 'trec' dataset and read the coarse label names from its train split.
trec_dataset = datasets.load_dataset('trec')
label_names = trec_dataset['train'].features['label-coarse'].names

# Convert the dataset with the project helper and build the active learner,
# passing it the number of coarse labels.
trec_dataset_st = convert_to_small_text_dataset(trec_dataset)
active_learner = build_active_learner(trec_dataset_st, len(label_names))

# Initialize the learner from the converted dataset's labels, then log the
# returned indices to Rubrix as already-validated records.
initial_indices = initialize_active_learner(active_learner, trec_dataset_st.y)
initialize_rubrix(initial_indices, trec_dataset, label_names)



  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

10 records logged to http://localhost:6900/datasets/rubrix/active-learning-test-batch-initial


In [4]:
# Reverse lookup from label name to its integer index; the listener below uses
# it to turn record annotations back into label indices.
label_name_to_idx = {name: position for position, name in enumerate(label_names)}

# Query the first batch from the active learner and log it to Rubrix as batch 0.
batch_idx = 0
queried_indices = active_learner.query()
log_next_batch(batch_idx, trec_dataset, queried_indices)

Logging records for batch 0


  0%|          | 0/10 [00:00<?, ?it/s]

10 records logged to http://localhost:6900/datasets/rubrix/active-learning-test-batch


In [9]:
def condition(search):
    """Return True once the search reports at least 10 matching records.

    Args:
        search: Object exposing a ``total`` attribute with the number of
            records matched by the listener query.

    NOTE(review): 10 presumably mirrors the size of one queried batch (the
    logged outputs show 10 records per batch) - confirm against the learner.
    """
    enough_records = search.total >= 10
    return enough_records


@listener(
    dataset="active-learning-test-batch",
    # NOTE(review): the query uses a lowercase `and`; query-string syntaxes
    # usually document the operator as uppercase `AND` - confirm this matches
    # as intended before changing it.
    query="status:Validated and NOT _exists_:metadata.processed",
    condition=condition,
    execution_interval_in_seconds=5,
)
def next_loop_step_2(records, ctx):
    """Feed a validated batch to the active learner and log the next batch.

    Registered as a Rubrix listener on ``active-learning-test-batch`` with
    ``execution_interval_in_seconds=5``; it receives the records matching the
    query (validated and without a ``processed`` metadata key) once
    ``condition`` holds.

    Steps:
      1. Order the records by the position encoded in their id.
      2. Update the active learner with the annotated label indices.
      3. Flag every record as processed and determine the highest batch id.
      4. Query the next batch and log it under the following batch id.
      5. Re-log the flagged records so they no longer match the query.

    Args:
        records: Records matched by the listener query.
        ctx: Listener context; only ``ctx.dataset`` is read here.
    """
    def _query_position(record):
        # Ids are written by log_next_batch as "<batch_idx>_<position>", where
        # position is the index of the sample within the queried batch.
        _, _, position = str(record.id).rpartition("_")
        return (record.metadata["batch_id"], int(position))

    # The learner is updated with a bare label array, so the labels presumably
    # have to line up with the order of the last query() result - restore that
    # order explicitly instead of trusting the order the backend returns.
    try:
        records = sorted(records, key=_query_position)
    except (TypeError, ValueError):
        # Ids that do not follow the expected scheme cannot be ordered; fall
        # back to the order in which the records were delivered.
        records = list(records)

    new_labels = [label_name_to_idx[r.annotation] for r in records]
    active_learner.update(numpy.array(new_labels))

    batch_idx = 0
    for r in records:
        r.metadata["processed"] = True
        batch_idx = max(batch_idx, r.metadata["batch_id"])

    # The next batch continues after the highest batch id seen in this run.
    batch_idx += 1
    queried_indices = active_learner.query()
    log_next_batch(batch_idx, trec_dataset, queried_indices)

    # Persist the processed flag so these records are excluded next time.
    rb.log(name=ctx.dataset, records=records, background=True)

In [10]:
# Start the listener defined above; the batch logs that follow in the output
# are produced by its runs.
next_loop_step_2.start()

Logging records for batch 1




  0%|          | 0/10 [00:00<?, ?it/s]

10 records logged to http://localhost:6900/datasets/rubrix/active-learning-test-batch


  0%|          | 0/10 [00:00<?, ?it/s]

10 records logged to http://localhost:6900/datasets/rubrix/active-learning-test-batch
Logging records for batch 2




  0%|          | 0/10 [00:00<?, ?it/s]

10 records logged to http://localhost:6900/datasets/rubrix/active-learning-test-batch


  0%|          | 0/10 [00:00<?, ?it/s]

10 records logged to http://localhost:6900/datasets/rubrix/active-learning-test-batch
Logging records for batch 3




  0%|          | 0/10 [00:00<?, ?it/s]

10 records logged to http://localhost:6900/datasets/rubrix/active-learning-test-batch


  0%|          | 0/10 [00:00<?, ?it/s]

10 records logged to http://localhost:6900/datasets/rubrix/active-learning-test-batch
Logging records for batch 4




  0%|          | 0/10 [00:00<?, ?it/s]

10 records logged to http://localhost:6900/datasets/rubrix/active-learning-test-batch


  0%|          | 0/10 [00:00<?, ?it/s]

10 records logged to http://localhost:6900/datasets/rubrix/active-learning-test-batch
