In [None]:
import os

import argilla as rg

# Connection settings are taken from the environment when present so that real
# credentials never need to be written into the notebook. The fallbacks are the
# notebook's original placeholder values.
url = os.environ.get("ARGILLA_API_URL", "")
api = os.environ.get("ARGILLA_API_KEY", "owner.apikey")

rg.init(
    api_url=url,
    api_key=api,
)

In [None]:
import calendar
from datetime import datetime
import re
import time

import requests
from transformers import pipeline
import tqdm

from unstructured.partition.html import partition_html
from unstructured.documents.elements import NarrativeText, ListItem
from unstructured.staging.argilla import stage_for_argilla

In [None]:
import nltk
# Download the NLTK "averaged_perceptron_tagger" resource.
# NOTE(review): presumably needed by unstructured's partitioning — confirm.
nltk.download("averaged_perceptron_tagger")

In [None]:
ISW_BASE_URL = "https://www.understandingwar.org/backgrounder/russian-offensive-campaign-assessment"

# English month names are spelled out explicitly: strftime("%B") follows the
# process locale, so a non-English locale would produce a wrong URL slug.
_MONTH_SLUGS = (
    "january", "february", "march", "april", "may", "june",
    "july", "august", "september", "october", "november", "december",
)


def datetime_to_url(dt):
    """Build the ISW assessment URL for a given date.

    Args:
        dt: A ``date`` or ``datetime``; only its month and day are used.

    Returns:
        ``ISW_BASE_URL`` followed by ``-<lowercase english month>-<day>``,
        with the day written without zero padding.
    """
    return f"{ISW_BASE_URL}-{_MONTH_SLUGS[dt.month - 1]}-{dt.day}"

year = 2022

# One URL per calendar day from March through December of `year`,
# in chronological order.
urls = [
    datetime_to_url(datetime(year, month_no, day_no))
    for month_no in range(3, 13)
    for day_no in range(1, calendar.monthrange(year, month_no)[1] + 1)
]

In [None]:
def url_to_elements(url, timeout=30):
    """Download a page and partition its HTML into elements.

    Args:
        url: Address of the page to fetch.
        timeout: Seconds to wait for the server before giving up. Without a
            timeout, ``requests`` can block indefinitely on a stalled
            connection.

    Returns:
        The result of ``partition_html`` on the response body, or ``None``
        when the request fails or the response status is not 200.
    """
    try:
        response = requests.get(url, timeout=timeout)
    except requests.RequestException:
        # Report network-level failures the same way as a bad status so one
        # unreachable page does not abort a long scraping run.
        return None

    if response.status_code != 200:
        return None

    return partition_html(text=response.text)

In [None]:
def _find_key_takeaways_idx(elements):
    for idx, element in enumerate(elements):
        if element.text == "Key Takeaways":
            return idx

def get_key_takeaways(elements):
    """Collect the bullet points listed under the "Key Takeaways" heading.

    Args:
        elements: Partitioned page elements, in document order.

    Returns:
        A ``NarrativeText`` whose text is the space-joined text of the
        consecutive ``ListItem`` elements directly following the heading, or
        ``None`` when the heading is missing or no list items follow it.
    """
    key_takeaways_idx = _find_key_takeaways_idx(elements)
    # Compare against None explicitly: index 0 is a valid position and must
    # not be mistaken for "heading not found".
    if key_takeaways_idx is None:
        return None

    takeaway_texts = []
    for element in elements[key_takeaways_idx + 1:]:
        # The takeaways are the unbroken run of list items after the heading.
        if not isinstance(element, ListItem):
            break
        takeaway_texts.append(element.text)

    # An empty summary is useless as an annotation; report it as "not found".
    if not takeaway_texts:
        return None

    return NarrativeText(text=" ".join(takeaway_texts))

In [None]:
# Spot-check one generated URL (index 200 of `urls`).
urls[200]

In [None]:
# Fetch and partition the sample page; `elements` is None if the page could not be fetched.
elements = url_to_elements(urls[200])

In [None]:
# NOTE: this import rebinds the name `print` to rich's implementation for every later cell.
from rich import print
# Show the key takeaways extracted from the sample page.
print(get_key_takeaways(elements))

In [None]:
# Show a sample of narrative text
def get_narrative(elements):
    """Combine the long narrative paragraphs of a page into a single element.

    Only ``NarrativeText`` elements whose text is longer than 500 characters
    are kept. Bracketed numeric citations such as ``[3]`` are removed and the
    remaining paragraphs are separated by blank lines.

    Args:
        elements: Partitioned page elements, in document order.

    Returns:
        A ``NarrativeText`` holding the combined text with surrounding
        whitespace stripped (empty text when no element qualifies).
    """
    paragraphs = [
        # NOTE: Removes citations like [3] from the text. The pattern is a raw
        # string because "\[" and "\d" are invalid escapes in a plain literal.
        re.sub(r"\[\d{1,3}\]", "", element.text)
        for element in elements
        if isinstance(element, NarrativeText) and len(element.text) > 500
    ]
    return NarrativeText(text="\n\n".join(paragraphs).strip())

# Preview the first 2000 characters of the combined narrative for the sample page.
print(get_narrative(elements).text[0:2000])

In [None]:
# Parallel lists: inputs[i] is the narrative element for a page and
# annotations[i] is the key-takeaways text for that same page.
inputs = []
annotations = []

for url in tqdm.tqdm(urls):
    elements = url_to_elements(url)

    # Pages that could not be fetched or produced no elements are skipped.
    if elements:
        text = get_narrative(elements)
        annotation = get_key_takeaways(elements)

        # Keep a page only when both the narrative and the summary are
        # non-empty; an empty side makes the training pair useless.
        has_text = bool(text.text)
        has_summary = annotation is not None and bool(annotation.text)
        if has_text and has_summary:
            inputs.append(text)
            annotations.append(annotation.text)

    # NOTE: Sleeping to reduce the volume of requests to ISW. This runs for
    # every URL, including the ones that were skipped above, so failed requests
    # are rate-limited as well.
    time.sleep(1)

In [None]:
# Stage the narratives for Argilla's 'text2text' task, passing the key-takeaways texts as annotations.
dataset_rg =  stage_for_argilla(inputs, 'text2text', annotation=annotations)

In [None]:
# Preview the first three staged records via the dataset's pandas conversion.
dataset_rg.to_pandas().head(3)

In [None]:
# Send the staged records to Argilla under the name 'isw_summarise' in the 'hfgilla' workspace.
rg.log(dataset_rg, name='isw_summarise', workspace='hfgilla')

In [None]:
# Load the 'isw_summarise' records back from the 'hfgilla' workspace and convert them with to_datasets().
training_data = rg.load(name="isw_summarise",workspace='hfgilla').to_datasets()

In [None]:
# Display the loaded training data.
training_data

In [None]:
from transformers import AutoTokenizer
# Checkpoint name used to load the tokenizer here and the model in a later cell.
model_checkpoint = "t5-small"
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

# Maximum token lengths used (with truncation) for the input documents and the target summaries.
max_input_length = 1024
max_target_length = 128

def preprocess_function(examples):
    """Tokenize a batch of records for sequence-to-sequence training.

    The "text" column is encoded as the model input (truncated to
    ``max_input_length``) and the "annotation" column as the target
    (truncated to ``max_target_length``). The target token ids are stored
    under the "labels" key of the returned encoding.
    """
    source_docs = list(examples["text"])
    encoded = tokenizer(source_docs, max_length=max_input_length, truncation=True)

    # Targets are encoded inside the tokenizer's target-mode context.
    with tokenizer.as_target_tokenizer():
        target_ids = tokenizer(
            examples["annotation"],
            max_length=max_target_length,
            truncation=True,
        )["input_ids"]

    encoded["labels"] = target_ids
    return encoded

# Apply the preprocessing to the training data in batches.
tokenized_datasets = training_data.map(preprocess_function, batched=True)

In [None]:
from transformers import (
    AutoModelForSeq2SeqLM,
    DataCollatorForSeq2Seq,
    Seq2SeqTrainingArguments,
    Seq2SeqTrainer
)

# Load the pretrained sequence-to-sequence model for the same checkpoint as the tokenizer.
model = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint)

In [None]:
# Same batch size for training and evaluation.
batch_size = 16

# Last path component of the checkpoint id, e.g. "org/name" -> "name".
model_name = model_checkpoint.split("/")[-1]

args = Seq2SeqTrainingArguments(
    # Output directory is derived from the checkpoint name, so changing
    # `model_checkpoint` does not leave a stale hard-coded directory name.
    f"{model_name}-isw-summaries",
    evaluation_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    weight_decay=0.01,
    save_total_limit=3,
    num_train_epochs=1,
    predict_with_generate=True,
    fp16=False,
    push_to_hub=False,
)

In [None]:
# Collator that assembles tokenized examples into batches for the trainer.
data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)

In [None]:
# NOTE(review): the same tokenized dataset is used for training and for
# evaluation, so the reported evaluation metrics are measured on training data.
trainer = Seq2SeqTrainer(
    model=model,
    args=args,
    data_collator=data_collator,
    train_dataset=tokenized_datasets,
    eval_dataset=tokenized_datasets,
    tokenizer=tokenizer,
)

# Run fine-tuning with the configuration above.
trainer.train()