In [None]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import LoraConfig, get_peft_model
from peft import prepare_model_for_int8_training
import bitsandbytes as bnb

# Hub id shared by the base model and its tokenizer.
BASE_MODEL_ID = "NYTK/PULI-GPT-3SX"

# Load the base causal LM with 8-bit weights, then prepare it for int8 training.
model = AutoModelForCausalLM.from_pretrained(BASE_MODEL_ID, device_map="auto", load_in_8bit=True)
model = prepare_model_for_int8_training(model)

# LoRA adapter hyper-parameters.
config = LoraConfig(
    task_type="CAUSAL_LM",
    r=8,
    lora_alpha=16,
    lora_dropout=0.05,
    bias="none",
    # NOTE(review): confirm these names actually match the attention projection
    # modules of this architecture with the installed peft version.
    target_modules=["query", "value"],
)
model = get_peft_model(model, config)

tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_ID)
# NOTE(review): the model's embeddings are not resized after registering '[PAD]';
# verify the pad id is either already in the vocabulary or never fed to the model.
tokenizer.add_special_tokens({'pad_token': '[PAD]'})

In [None]:
# to load a previous model for further finetuning

from peft import set_peft_model_state_dict

# Read the adapter weights saved by an earlier run of this notebook.
adapters_weights = torch.load('kmonitor-entity-v1-20k//adapter_model.bin')
# set_peft_model_state_dict copies the weights into `model` in place. Its return
# value is not guaranteed to be the model (recent peft releases return the
# load result instead), so it must not be assigned back to `model`.
set_peft_model_state_dict(model, adapters_weights)

In [None]:
from datasets import load_from_disk

# Load the dataset previously saved to the local 'kmonitor_entity' directory.
data = load_from_disk('kmonitor_entity')

In [None]:
# Display the loaded dataset object to inspect it.
data

In [None]:
def trunc_to(descr, n):
    """Shorten ``descr`` to at most ``n`` characters, preferring a clean break.

    Text that already fits within ``n`` characters is returned unchanged.
    Otherwise the first ``n`` characters are taken and cut back to the most
    natural boundary found in the second half of that window:

    1. just after the last ``'.'`` (sentence boundary), if there is one;
    2. else just after the last ``' '`` (word boundary), if there is one;
    3. else a hard cut at ``n`` characters.

    Args:
        descr: The text to shorten.
        n: Maximum number of characters to keep.

    Returns:
        A prefix of ``descr`` no longer than ``n`` characters.
    """
    # Nothing to truncate: never drop content from text that already fits.
    if len(descr) <= n:
        return descr

    window = descr[:n]
    # Only accept a boundary in the second half so the result is not cut too short.
    tail = window[n // 2:]
    if '.' in tail:
        return window[:window.rfind('.') + 1]
    if ' ' in tail:
        return window[:window.rfind(' ') + 1]
    return window

def _entity_prompt(sample):
    """Render one dataset row as the training prompt text.

    The prompt holds the article text (shortened via ``trunc_to`` to 2500
    characters), a ``###`` separator, the person name, and the topic label
    derived from ``is_corrupt``.
    """
    topic = 'korrupció' if sample['is_corrupt'] else 'egyéb'
    return (
        '[entitás]\n'
        + trunc_to(sample['text'], 2500)
        + '\n\n###\n\nszemély: '
        + sample['person']
        + '\ntéma: '
        + topic
        + '\n'
    )


# Tokenize every row individually (batched=False), adding the tokenizer outputs to the dataset.
data = data.map(lambda sample: tokenizer(_entity_prompt(sample)), batched=False)

In [None]:
# Drop the raw source columns from the train split only; other splits are left untouched.
raw_columns = ['text', 'person', 'is_corrupt']
data['train'] = data['train'].remove_columns(raw_columns)

data

In [None]:
# Training subset: the first 20000 rows of the train split.
strain = data['train'].select(range(20000))

In [None]:
from transformers import Trainer


class CustomTrainer(Trainer):
    """Trainer whose LM loss is halved when the model already predicts the label token.

    The loss returned from ``compute_loss`` stays attached to the autograd graph,
    so the Trainer's own backward pass (including fp16 loss scaling and gradient
    accumulation) operates on it; no manual ``backward()`` is performed here.
    """

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Compute the language-modelling loss, down-weighted on a correct label prediction.

        Args:
            model: The model being trained.
            inputs: Batch mapping with at least ``input_ids`` and ``attention_mask``;
                it is forwarded unchanged to ``model(**inputs)``.
            return_outputs: When true, return ``(loss, outputs)`` instead of the loss alone.
            num_items_in_batch: Accepted for compatibility with transformers releases
                that pass it to ``compute_loss``; not used.

        Returns:
            The scalar loss tensor, or ``(loss, outputs)`` if ``return_outputs`` is true.

        Side effects:
            Increments ``self.n_step`` on every call and saves the model every 1000 calls.
            Stores the heuristic value 0.0 / 0.5 / 1.0 in ``self.last_label_penalty``.
        """
        # Count calls and write a checkpoint every 1000 of them.
        self.n_step = getattr(self, 'n_step', 0) + 1
        if self.n_step % 1000 == 0:
            model.save_pretrained("./kmonitor-gpt-fixed-loss-checkpoint-"+str(self.n_step))

        input_ids = inputs.get("input_ids")
        attention_mask = inputs.get("attention_mask")

        # The second-to-last token of the first sample is treated as the label.
        # NOTE(review): only sample 0 is inspected and the last token is assumed to be
        # the prompt's trailing newline — confirm before raising the batch size.
        label = tokenizer.decode([input_ids[0][-2]]).strip()

        output = model(**inputs)
        lm_loss = output['loss']

        # Let the model predict one token from the prompt with the last two tokens removed.
        generated = model.generate(
            input_ids=input_ids[:, :-2],
            attention_mask=attention_mask[:, :-2],
            max_new_tokens=1,
        )
        result = tokenizer.decode([generated[0][-1]]).strip()

        # TODO better loss calculation
        penalty = 0.5
        if label in result:
            # Correct prediction: halve the loss. Out-of-place division keeps
            # the model's output tensor untouched.
            penalty = 0.0
            lm_loss = lm_loss / 2
        elif label not in ['egyéb', 'korrupció']:
            penalty = 1.0
        # Kept for inspection only; it does not take part in the gradient.
        self.last_label_penalty = penalty

        return (lm_loss, output) if return_outputs else lm_loss

In [None]:
import transformers

# Hyper-parameters for the fine-tuning run.
training_args = transformers.TrainingArguments(
    output_dir="outputs",
    report_to='wandb',
    logging_steps=10,
    save_total_limit=1,
    # NOTE(review): passing this here may not make Trainer resume by itself;
    # confirm whether trainer.train(resume_from_checkpoint=...) was intended.
    resume_from_checkpoint=True,
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    gradient_accumulation_steps=1,
    # NOTE(review): max_steps and num_train_epochs are both set — confirm which
    # one is meant to bound the run.
    num_train_epochs=1,
    max_steps=20000,
    warmup_steps=300,
    # NOTE(review): 9e-3 looks high for adapter fine-tuning — confirm it is intentional.
    learning_rate=9e-3,
    weight_decay=0.0,
    fp16=True,
    # callbacks=[SavePeftModelCallback],  (disabled)
)

# Collator for causal language modelling (mlm=False).
causal_lm_collator = transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False)

trainer = CustomTrainer(
    model=model,
    args=training_args,
    train_dataset=strain,
    data_collator=causal_lm_collator,
)

# Turn off the generation cache on the model config before training starts.
model.config.use_cache = False
trainer.train()

In [None]:
# Save the trained model into ./kmonitor-entity-v1-20k/ (the directory the
# "load a previous model" cell reads adapter_model.bin from).
model.save_pretrained("./kmonitor-entity-v1-20k/")