In [4]:
import typing
import os

import pandas
import datasets
import trl
import peft

import cltrier_lib

In [5]:
os.environ["CUDA_VISIBLE_DEVICES"] = "2"

In [6]:
MODEL_SLUG: str = "meta-llama/Llama-3.2-3B-Instruct"
UPLOAD_SLUG: str = f"{MODEL_SLUG.split('/')[-1]}-OSN-posts"

RAW_DATASET: str = "../data/interim/twitter.german.dataset.enriched.csv"

SFT_ARGS = trl.SFTConfig(
    num_train_epochs=8,
    per_device_train_batch_size=4,
    packing=True, 
    save_strategy="no",
    output_dir=f"../models/{UPLOAD_SLUG}",
    logging_steps=50,
    push_to_hub=True,
    push_to_hub_model_id=UPLOAD_SLUG
)

PEFT_ARGS = peft.LoraConfig(
    r=8,
    lora_alpha=16,
    lora_dropout=0.1,
    task_type="CAUSAL_LM",
)



In [7]:
dataset: typing.List[typing.Dict] = [
    cltrier_lib.inference.schemas.Chat(messages=[
        cltrier_lib.inference.schemas.Message(role="system", content=f"You are a social media user with a political {row['leaning_post']} leaning. Post a Tweet about the following topic:"),
        cltrier_lib.inference.schemas.Message(role="user", content=row["topics_post"]),
        cltrier_lib.inference.schemas.Message(role="assistant", content=row["text_post"])
    ]).model_dump()
    for _, row in pandas.read_csv(RAW_DATASET, index_col=0).iterrows()
]
dataset[:3]

[{'messages': [{'role': 'system',
    'content': 'You are a social media user with a political neutral leaning. Post a Tweet about the following topic:'},
   {'role': 'user', 'content': 'Klima, eFuels, Verbrenner'},
   {'role': 'assistant',
    'content': 'Nicht der #Verbrenner schadet dem #Klima, sondern der fossile Sprit, mit dem er fährt. Wir haben diese Woche den Weg für klimaneutrale #eFuels freigemacht. Damit könnten die mehr als 45 Mio. Diesel- und Benzin-Fahrzeuge auf unseren Straßen in Zukunft klimaneutral unterwegs sein.'}]},
 {'messages': [{'role': 'system',
    'content': 'You are a social media user with a political right leaning. Post a Tweet about the following topic:'},
   {'role': 'user', 'content': 'Kernkraft, Altparteien, AfD'},
   {'role': 'assistant',
    'content': 'Wo waren die ganzen plötzlichen #Kernkraftbefürworter in #Altparteien, Verbänden &amp; Medien in den letzten Jahren? Warum stimmte die #umfaller: #fdp bis zuletzt im Bundestag gegen Laufzeitverlängerun

In [5]:
trainer = trl.SFTTrainer(
    MODEL_SLUG,
    args=SFT_ARGS,
    train_dataset=datasets.Dataset.from_pandas(pandas.DataFrame(data=dataset)),
    peft_config=PEFT_ARGS,
)

trainer.train()

trainer.save_model(SFT_ARGS.output_dir)

if SFT_ARGS.push_to_hub:
    trainer.push_to_hub()



Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]



Generating train split: 0 examples [00:00, ? examples/s]

Step,Training Loss
50,2.3106
100,1.8385
150,1.6834
200,1.5619
250,1.5118
300,1.4448
350,1.3697
400,1.3044
450,1.2545
500,1.2342


Upload 3 LFS files:   0%|          | 0/3 [00:00<?, ?it/s]

tokenizer.json:   0%|          | 0.00/17.2M [00:00<?, ?B/s]

adapter_model.safetensors:   0%|          | 0.00/9.19M [00:00<?, ?B/s]

training_args.bin:   0%|          | 0.00/5.56k [00:00<?, ?B/s]

No files have been modified since last commit. Skipping to prevent empty commit.
