# Loading text block data into Argilla 

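Loads text blocks into the `sector-text-classifier` Argilla dataset for a multi-label sector classification task. It first copies the text blocks already labelled in the 'sentence or text block' task, then adds a random sample of unlabelled text blocks from the English-language GST documents.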

In [1]:
import sys

!{sys.executable} -m pip install argilla


[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m23.0[0m[39;49m -> [0m[32;49m23.1.2[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpython -m pip install --upgrade pip[0m


In [2]:
import os
import random

from cpr_data_access.models import Dataset, BaseDocument
from dotenv import load_dotenv, find_dotenv
import argilla as rg
from tqdm.auto import tqdm
import spacy

# Read settings such as ARGILLA_API_KEY and DOCS_DIR_GST from the nearest .env file;
# override=True lets .env values replace variables already set in the environment.
load_dotenv(dotenv_path=find_dotenv(), override=True)

# Small English spaCy pipeline; in this notebook it is used to compute a vector per text block.
nlp = spacy.load("en_core_web_sm")

  from .autonotebook import tqdm as notebook_tqdm


In [24]:
# Config

# Name of the Argilla dataset that records are logged to.
DATASET_NAME = "sector-text-classifier"
# Maximum number of text blocks sampled from each document.
TEXT_BLOCKS_PER_DOCUMENT = 20

# Sector labels available to annotators.
SECTOR_LABELS = [
    "energy",
    "transport",
    "industry",
    "buildings",
    "agriculture, forestry and other land use",
    "fisheries & aquaculture",
    "insurance & financial services",
    "water services",
    "health services",
    "tourism",
]

settings = rg.TextClassificationSettings(label_schema=SECTOR_LABELS)

In [11]:
# Connect to Argilla. User management is done at a workspace level, so every dataset
# read or written below lives in the "gst" workspace.
rg.init(api_key=os.environ["ARGILLA_API_KEY"], workspace="gst")

# Register the sector label schema on the target dataset.
rg.configure_dataset(settings=settings, name=DATASET_NAME)

## 1. Load labelled text blocks from the 'sentence or text block' task

In [23]:
SENTENCE_OR_TEXT_BLOCK_DATASET_NAME = "sectors-sentence-or-text-block"

# Load the records from the earlier task and keep only those tagged as text blocks
# (not sentences). The kept record objects are logged as-is into this dataset.
sent_or_text_block_dataset = rg.load(SENTENCE_OR_TEXT_BLOCK_DATASET_NAME)
text_blocks_only = [
    record
    for record in sent_or_text_block_dataset
    # .get(): a record without the "sentence_or_text_block" tag is excluded
    # instead of aborting the whole cell with a KeyError.
    if (record.metadata or {}).get("sentence_or_text_block") == "text_block"
]
rg.log(
    text_blocks_only,
    name=DATASET_NAME,
)

BulkResponse(dataset='sector-text-classifier', processed=193, failed=0)

## 2. Load a sample of unlabelled text blocks

In [13]:
# Load every document found in the local directory named by DOCS_DIR_GST,
# then keep only the documents that filter_by_language("en") retains.
dataset = Dataset(document_model=BaseDocument).load_from_local(
    os.environ["DOCS_DIR_GST"]
).filter_by_language("en")

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1142/1142 [01:42<00:00, 11.11it/s]


In [25]:
def _record_id(document, block) -> str:
    """Return the Argilla record id for a text block: "<text_block_id>_<document_id>"."""
    return f"{block.text_block_id}_{document.document_id}"


# Ids of the records already logged in section 1. Logging a record whose id already
# exists in an Argilla dataset updates that record, so an unlabelled copy of one of
# these blocks could overwrite its existing annotation.
# NOTE(review): assumes those records use the same id format as `_record_id` — confirm.
labelled_record_ids = {record.id for record in text_blocks_only}

records = []

for document in tqdm(dataset.documents):
    if not document.text_blocks:
        # tqdm.write prints without corrupting the progress bar (unlike print).
        tqdm.write(f"Skipping {document.document_id} as no text blocks")
        continue

    doc_metadata = document.dict(exclude={"text_blocks", "page_metadata"})

    candidate_blocks = [
        block
        for block in document.text_blocks
        if _record_id(document, block) not in labelled_record_ids
    ]

    # Randomly sample a fixed number of text blocks per document. The RNG is seeded
    # from the document id, so the same blocks are chosen on every run, independent
    # of document order and of any earlier use of the global `random` state.
    if len(candidate_blocks) <= TEXT_BLOCKS_PER_DOCUMENT:
        blocks = candidate_blocks
    else:
        rng = random.Random(str(document.document_id))
        blocks = rng.sample(candidate_blocks, TEXT_BLOCKS_PER_DOCUMENT)

    # Collapse every whitespace run (newlines, tabs, repeated spaces) to one space.
    # A single .replace("  ", " ") pass only halved runs of spaces.
    # Blocks that are empty after normalisation have nothing to classify.
    block_texts = [" ".join(block.to_string().split()) for block in blocks]
    texts_and_blocks = [
        (block_text, block)
        for block_text, block in zip(block_texts, blocks)
        if block_text
    ]

    # nlp.pipe processes the texts in batches, which is faster than one nlp() call per block.
    spacy_docs = nlp.pipe(block_text for block_text, _ in texts_and_blocks)
    for (block_text, block), spacy_doc in zip(texts_and_blocks, spacy_docs):
        block_metadata = block.dict(exclude={"text"})

        records.append(
            rg.TextClassificationRecord(
                text=block_text,
                multi_label=True,
                metadata=doc_metadata | block_metadata,
                id=_record_id(document, block),
                # .tolist() yields plain Python floats rather than numpy scalars.
                vectors={"spacy": spacy_doc.vector.tolist()},
            )
        )

100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 929/929 [03:26<00:00,  4.49it/s]


In [26]:
# Dataset-level metadata: the id of every document in the loaded (English-filtered)
# dataset, including any documents skipped above for having no text blocks.
dataset_metadata = dict(
    documents=[document.document_id for document in dataset.documents]
)

rg.log(records, name=DATASET_NAME, metadata=dataset_metadata)

BulkResponse(dataset='sector-text-classifier', processed=16632, failed=0)