# Setup

In [None]:
# @title System update
%%capture
!apt update && apt upgrade -y
!uv pip install --upgrade pip

In [None]:
import os

from google.colab import userdata

# Copy the Colab secrets into environment variables so that the `%%bash`
# cells can read them as $GIT_TOKEN, $USER_NAME and $USER_MAIL.
for env_name, secret_name in (
    ('GIT_TOKEN', 'git_token'),
    ('USER_NAME', 'user_name'),
    ('USER_MAIL', 'user_mail'),
):
    os.environ[env_name] = userdata.get(secret_name)

In [None]:
%%bash
# Configure the commit identity from the environment variables exported earlier.
git config --global user.name "$USER_NAME"
git config --global user.email "$USER_MAIL"
# Skip the clone when a checkout already exists so the cell can be re-run
# without failing on "destination path already exists".
# NOTE: the token is embedded in the remote URL and therefore ends up in
# ethics-model/.git/config; treat the checkout as containing a secret.
if [ ! -d ethics-model/.git ]; then
    git clone "https://${GIT_TOKEN}@github.com/bjoernbethge/ethics-model.git"
fi


In [None]:
# Switch the kernel's working directory to the cloned repository; later cells
# use repo-relative paths such as `.venv/...` and `checkpoints/...`.
%cd ethics-model

In [None]:
%%bash

# Sync the project's dependencies with uv, including the optional "train" extra.
uv sync --extra train
# Mark the virtual environment's activation script as executable.
chmod +x .venv/bin/activate

In [None]:
%%bash
# NOTE(review): `%%bash` runs this cell in its own shell process, so the
# activation ends with the cell and does not change the interpreter used by
# the notebook kernel — confirm whether this cell is still needed.
source .venv/bin/activate

In [None]:
!.venv/bin/activate

In [None]:
# @title Imports

import os
import torch
import random
import logging
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from datasets import load_dataset
import nlpaug.augmenter.word as naw

from ethics_model.model import EthicsModel
from ethics_model.data import MultiTaskDataset
from ethics_model.training import train






# Configuration

In [None]:
# @title  Logging & directories
from google.colab import drive

# Make Google Drive available under /content/drive.
drive.mount('/content/drive')

# Timestamped INFO-level logging for the whole notebook.
logging.basicConfig(
    format='%(asctime)s %(levelname)s %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# Run on the GPU when one is visible to torch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device('cuda')
else:
    DEVICE = torch.device('cpu')

# Output locations, relative to the current working directory.
CHECKPOINT_PATH = "checkpoints/best_ethics_model.pt"
TENSORBOARD_LOGDIR = "runs/ethics_llm_train"

# Create the checkpoint folder up front; a no-op when it already exists.
os.makedirs(
    os.path.dirname(CHECKPOINT_PATH),
    exist_ok=True,
)

In [None]:
# @title  LLM & Tokenizer
huggingface_model = 'unsloth/gemma-3-4b-it-unsloth-bnb-4bit' # @param {type:"string"}

# Load the weights in 4-bit precision through bitsandbytes.
bnb_config = BitsAndBytesConfig(load_in_4bit=True)

# Tokenizer and model are loaded from the same checkpoint id selected in the
# form field above (the cell previously referenced an undefined name here).
tokenizer = AutoTokenizer.from_pretrained(huggingface_model)
llm = AutoModelForCausalLM.from_pretrained(
    huggingface_model,
    quantization_config=bnb_config,
    device_map="auto"
)

# Put the LLM into eval mode; the trailing `;` keeps the (very long) module
# repr out of the cell output.
llm.eval();


In [None]:
# @title Model Configuration

n_layers = 2  # @param {type:"integer", min:1, max:12, step:1}
n_heads = 8  # @param {type:"integer", min:1, max:16, step:1}
max_seq_length = 128  # @param {type:"integer", min:64, max:512, step:64}
activation = 'gelu'  # @param ["gelu", "relu", "tanh"]
use_gnn = False  # @param {type:"boolean"}

# Composite (e.g. multimodal) checkpoints keep the language-model settings in a
# nested `text_config`, while plain causal LMs expose `hidden_size` directly on
# the top-level config. Support both layouts.
llm_text_config = getattr(llm.config, 'text_config', None) or llm.config
hidden_size = llm_text_config.hidden_size

# The ethics model consumes the LLM's hidden states, so both its input width
# and its internal width are tied to the LLM hidden size.
model_config = {
    'input_dim': hidden_size,
    'd_model': hidden_size,
    'n_layers': n_layers,
    'n_heads': n_heads,
    'vocab_size': tokenizer.vocab_size,
    'max_seq_length': max_seq_length,
    'activation': activation,
    'use_gnn': use_gnn
}

model = EthicsModel(**model_config).to(DEVICE)
# Only the ethics model's parameters are handed to the optimizer.
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)
# BCELoss expects predictions and targets in [0, 1].
# NOTE(review): confirm that EthicsModel emits probabilities and that the
# labels fed to it are scaled accordingly.
criterion = torch.nn.BCELoss()

# Data

In [None]:
# @title  Augmentation

# WordNet-backed synonym replacement.
# NOTE(review): the augmenter is created with its default language while the
# dataset name suggests German text — confirm the augmentation has any effect.
aug = naw.SynonymAug(aug_src='wordnet')


def synonym_augment(text):
    """Return a synonym-augmented variant of ``text``.

    Augmentation is best effort: if the augmenter raises, the input is
    returned unchanged so that data loading never fails because of it.

    Args:
        text: The text to augment.

    Returns:
        The augmented text. When ``text`` is a single string, a single string
        is returned even if the augmenter wraps its result in a list (as
        recent nlpaug releases do). Other inputs get the augmenter's result
        as-is.
    """
    try:
        augmented = aug.augment(text)
    except Exception:
        # Deliberate fallback, but leave a trace instead of failing silently.
        logging.getLogger(__name__).debug(
            "Synonym augmentation failed; using original text", exc_info=True
        )
        return text

    # Unwrap the single-element list produced for a single string input.
    if isinstance(text, str) and isinstance(augmented, (list, tuple)):
        return augmented[0] if augmented else text
    return augmented

In [None]:
# @title  Preparation

# Load the first 1000 rows of the training split.
ds = load_dataset("flozi00/Fineweb2-German-Eduscore-4andMore", split="train[:1000]")
texts = ds["text"]

# NOTE(review): the "eduscore" values are used unscaled as targets, while the
# loss configured in this notebook is BCELoss (targets in [0, 1]) — confirm
# the value range of this column.
ethics_labels = list(map(float, ds["eduscore"]))

# Use dedicated manipulation scores when the column exists; otherwise reuse
# the very same list of ethics labels for the second task.
if "manipulation_score" in ds.column_names:
    manipulation_labels = list(map(float, ds["manipulation_score"]))
else:
    manipulation_labels = ethics_labels

# Dataset with synonym augmentation enabled, served in shuffled batches of 4.
dataset = MultiTaskDataset(
    texts,
    ethics_labels,
    manipulation_labels,
    tokenizer,
    augment=True,
    synonym_augment=synonym_augment,
)
dataloader = DataLoader(dataset, shuffle=True, batch_size=4)

# Run

In [None]:
# @title  Training
writer = SummaryWriter(log_dir=TENSORBOARD_LOGDIR)
try:
    model_trained = train(model, llm, dataloader, optimizer, criterion, writer, DEVICE)
finally:
    # Always close the writer so pending TensorBoard events are written out,
    # even when training fails or is interrupted.
    writer.close()

In [None]:
# @title  Evaluation
def evaluate(model, llm, dataloader, tokenizer, device, max_batches=None):
    """Print the model's scores next to the labels for batches of ``dataloader``.

    The LLM's final-layer hidden states are obtained through its regular
    forward pass with ``output_hidden_states=True`` rather than by reaching
    into architecture-specific attributes. The attention mask is forwarded to
    the LLM as well, so padding positions are masked there too.

    Args:
        model: Callable as ``model(embeddings=..., attention_mask=...)``,
            returning a mapping with ``'ethics_score'`` and
            ``'manipulation_score'``. It is switched to eval mode.
        llm: Language model callable with ``input_ids``, ``attention_mask``
            and ``output_hidden_states``; its output must expose
            ``hidden_states``.
        dataloader: Iterable of batches with the keys ``'input_ids'``,
            ``'attention_mask'``, ``'ethics_label'`` and
            ``'manipulation_label'``.
        tokenizer: Used to decode ``input_ids`` back to text for printing.
        device: Device the batch tensors are moved to.
        max_batches: Optional cap on the number of batches to print;
            ``None`` (the default) processes the whole dataloader.

    Returns:
        None. Results are printed.
    """
    model.eval()
    with torch.no_grad():
        for batch_idx, batch in enumerate(dataloader):
            if max_batches is not None and batch_idx >= max_batches:
                break

            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            ethics_label = batch['ethics_label'].to(device)
            manipulation_label = batch['manipulation_label'].to(device)

            # The last entry of `hidden_states` is the final layer's output.
            llm_outputs = llm(
                input_ids=input_ids,
                attention_mask=attention_mask,
                output_hidden_states=True,
            )
            hidden_states = llm_outputs.hidden_states[-1]

            outputs = model(embeddings=hidden_states, attention_mask=attention_mask)
            ethics_score = outputs['ethics_score']
            manipulation_score = outputs['manipulation_score']
            print(f"Text: {tokenizer.batch_decode(input_ids, skip_special_tokens=True)}")
            print(f"Ethics Score: {ethics_score.squeeze(-1).cpu().numpy()} | Label: {ethics_label.squeeze(-1).cpu().numpy()}")
            print(f"Manipulation Score: {manipulation_score.squeeze(-1).cpu().numpy()} | Label: {manipulation_label.squeeze(-1).cpu().numpy()}")

# NOTE(review): `dataloader` is the same shuffled, augmented loader that was
# passed to `train`, so this prints predictions on the training data rather
# than on a held-out split.
evaluate(model_trained, llm, dataloader, tokenizer, DEVICE)