In [16]:
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import AutoPeftModelForCausalLM
from utils import load_yaml, load_args_promt
import torch, constants, json
import textwrap

In [2]:
# Multitask experiment configuration. Later cells read its prompt keys
# ('k-shot', 'promt_template'), model keys ('model_id', 'model_name') and the
# 'train' section from it.
settings = load_yaml(
    'config.multitask.yml'
)

In [3]:
# Checkpoint version; together with train.max_steps it forms part of the
# trained_models/ directory name used when the model is loaded.
version = 1

# Overrides applied on top of the config's training section.
settings['train'].update(batch_size=8, max_steps=256)

In [4]:
def set_textual_information(promt, promt_settings, tweet):
    """Render a prompt template for a single tweet.

    Args:
        promt: ``str.format`` template containing a ``{tweet}`` placeholder;
            every other placeholder must be a key of ``promt_settings``.
        promt_settings: Mapping of template field names to values for one
            task. It is read but not modified.
        tweet: Tweet text substituted for ``{tweet}``.

    Returns:
        The formatted prompt string.
    """
    # Build a per-call mapping instead of writing into ``promt_settings``:
    # callers pass the shared per-task dict, so mutating it would leave the
    # latest tweet stored there as a side effect.
    fields = {**promt_settings, 'tweet': tweet}
    return promt.format(**fields)

def fill_values_from_row(row, promt_template, promt_settings):
    """Build the model prompt for one sample row.

    ``row['task']`` selects which entry of ``promt_settings`` supplies the
    template arguments, and ``row['tweet']`` is the text placed into the
    template's ``{tweet}`` slot.
    """
    task_arguments = promt_settings[row['task']]
    tweet_text = row['tweet']
    return set_textual_information(promt_template, task_arguments, tweet_text)

# Prompt arguments keyed by task name; fill_values_from_row selects an entry
# with row['task'].
promt_settings = {}

if isinstance(settings["k-shot"], dict):
    # One k-shot specification per task.
    for task, _args_task in settings['k-shot'].items():
        promt_settings[task] = load_args_promt(settings, _args_task)
else:
    # A single k-shot file under promts/: keep the entry chosen by
    # 'k-selected' and register it under the run's tune_name.
    # NOTE(review): rows are looked up by row['task'], so this branch only
    # serves rows whose task equals settings['tune_name'] — confirm.
    # The JSON is decoded as UTF-8 explicitly rather than with the
    # platform's locale encoding, which can mis-decode non-ASCII content.
    with open(f'promts/{settings["k-shot"]}', encoding='utf-8') as buffer:
        promt_settings[settings['tune_name']] = json.load(buffer)[
            settings['k-selected']
        ]

# Final template: the configured prompt from `constants` followed by the
# shared question text.
question = constants.QUESTION
promt_template = getattr(constants, settings['promt_template']) + " " + question

print(textwrap.fill(promt_template))

You are a {model}. {task}: {categories_description} The tweet text is
delimited with triple backticks.  Generate a JSON with the following
keys: 'prediction' containing a single predicted category, either
{category_list}, and 'explanation' containing the reason (limited to
100 characters) for the categorization decision. Tweet text:
'''{tweet}'''


In [15]:
# A single example row used to smoke-test the fine-tuned model.
# Reference label for this tweet (not passed to the model): 'Sympathy and support'.
row = dict(
    tweet='Governor General Quentin Bryce has visited the Blue Mtns today, thanking firefighters #NSWRFS #nswfires http://t.co/gjt1bJPhHR',
    task='humanitarian',
    explanation='',
)

_input_sample = fill_values_from_row(
    row=row,
    promt_template=promt_template,
    promt_settings=promt_settings,
)
print(textwrap.fill(_input_sample))

You are a social media message classifier. Your task is to categorize
a tweet posted during crisis events as one of the following
humanitarian categories: - Affected individuals: information about
deaths, injuries, missing, trapped, found or displaced people,
including personal updates about oneself, family, or others. -
Infrastructure and utilities: information about buildings, roads,
utilities/services that are damaged, interrupted, restored or
operational. - Donations and volunteering: information about needs,
requests, queries or offers of money, blood, shelter, supplies (e.g.,
food, water, clothing, medical supplies) and/or services by volunteers
issued or lifted, guidance and tips. - Sympathy and support: thoughts,
prayers, gratitude, sadness, etc. - Other useful information:
information NOT covered by any of the above categories. It considers
information about fire line/emergency location, flood level,
earthquake magnitude, weather, wind, visibility; smoke, ash;
adjunctive and m

In [6]:
# Tokenizer for the base checkpoint, configured from `settings`.
# NOTE(review): settings['hub_id'] is passed as the Hugging Face auth
# `token`; if it is a real access token, keep it out of the committed
# config (e.g. read it from an environment variable).
tokenizer = AutoTokenizer.from_pretrained(
    settings['model_id'],
    token=settings['hub_id'],
    model_max_length=settings['model_max_length'],
    padding_side=settings['padding_side'],
)

# Pad with the EOS token; generation below passes the same id as pad_token_id.
tokenizer.pad_token = tokenizer.eos_token

In [7]:
# Locate the fine-tuned adapter checkpoint:
#   trained_models/<model_name>__<tune_name>-v<version>_<max_steps>
# max_steps falls back to -1 when the training section does not set it.
_steps = settings['train'].get('max_steps', -1)
_name = '__'.join([settings['model_name'], settings['tune_name']])
_filename = f"trained_models/{_name}-v{version}_{_steps}"

# Load the base model plus PEFT adapter in bfloat16 with automatic device
# placement, then merge the adapter weights into the base model for inference.
model = AutoPeftModelForCausalLM.from_pretrained(
    _filename,
    device_map="auto",
    torch_dtype=torch.bfloat16,
).merge_and_unload()

Loading checkpoint shards: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:08<00:00,  4.21s/it]


In [13]:
# Tokenize the single sample; truncation/padding follow the tokenizer
# configured earlier (model_max_length, padding_side).
_inputs = tokenizer(
    [_input_sample], return_tensors="pt",
    truncation=True, padding=True
)

# Move inputs to the device the model reports instead of a hard-coded 'cuda':
# with device_map="auto" the weights are not guaranteed to live on the
# default CUDA device (e.g. CPU-only machines or a different GPU).
_device = model.device
_results = model.generate(
    input_ids=_inputs["input_ids"].to(_device),
    attention_mask=_inputs["attention_mask"].to(_device),
    pad_token_id=tokenizer.eos_token_id,
    # NOTE(review): transformers only applies `temperature` when sampling is
    # enabled (do_sample=True, possibly inherited from the model's generation
    # config); otherwise it is ignored with a warning. Confirm the intended
    # decoding mode.
    max_new_tokens=256, temperature=settings['temperature']
)

# Full decoded text (prompt + completion), kept for later inspection.
reponse = tokenizer.decode(
    _results[0], skip_special_tokens=True
)

# Decode only the newly generated tokens so prompt text cannot be mistaken for
# the answer, then keep what follows the '[/]' marker if the model emitted one.
# Unlike `split('[/]')[1]`, this does not raise IndexError when the marker is
# missing.
_prompt_length = _inputs["input_ids"].shape[1]
_completion = tokenizer.decode(
    _results[0, _prompt_length:], skip_special_tokens=True
)
_, _marker, _answer = _completion.partition('[/]')
print(textwrap.fill(_answer if _marker else _completion))

 {"prediction": "Affected individuals", "explanation": "Tweet mentions
victims needing help and food."}
