In [1]:
import os
from dotenv import load_dotenv

# Set the device to physical GPU 3
# Physics server
os.environ["CUDA_VISIBLE_DEVICES"] = "3"

env_path = "./config/.env"
load_dotenv(dotenv_path=env_path)


True

In [2]:
device_map = {"": 0}
gpu_device   = 'cuda:0'

In [3]:
import torch

num_gpus = torch.cuda.device_count()
print(f"Found {num_gpus} GPUs available to PyTorch:")
print("-" * 40)

for i in range(num_gpus):
    name = torch.cuda.get_device_name(i)
    mem = torch.cuda.get_device_properties(i).total_memory / 1024**3
    print(f"Device Index {i}: {name} ({mem:.2f} GB)")

print("-" * 40)

Found 1 GPUs available to PyTorch:
----------------------------------------
Device Index 0: NVIDIA A100-SXM4-80GB (79.25 GB)
----------------------------------------


In [4]:
import torch
import torch.nn as nn

import pandas as pd
import numpy as np

from datasets import Dataset, load_dataset, DatasetDict 

from peft import (LoraConfig, 
                  PeftModel, 
                  prepare_model_for_kbit_training, 
                  get_peft_model,
                  PeftModelForSequenceClassification,
                  PeftConfig)

from transformers.modeling_outputs import SequenceClassifierOutput
from transformers import (
    AutoTokenizer,
    BitsAndBytesConfig,
    TrainingArguments,
    AutoModelForSequenceClassification, 
    TrainingArguments, 
    Trainer,
    EarlyStoppingCallback,
    DataCollatorWithPadding,
    AutoModelForCausalLM)

import bitsandbytes as bnb
import evaluate

  from .autonotebook import tqdm as notebook_tqdm


In [5]:
import os

# Configure the NTHU proxy directly in Python using the IP address
proxy_url = "http://140.114.63.4:3128"

os.environ['http_proxy'] = proxy_url
os.environ['https_proxy'] = proxy_url
os.environ['HTTP_PROXY'] = proxy_url
os.environ['HTTPS_PROXY'] = proxy_url

print("Proxy configured via IP address.")

Proxy configured via IP address.


In [6]:
from huggingface_hub import login
hf_token = os.getenv('HF_TOKEN')
login(hf_token)

Note: Environment variable`HF_TOKEN` is set and is the current active token independently from the token you've just configured.


In [18]:
dataset_toxic = load_dataset("thesofakillers/jigsaw-toxic-comment-classification-challenge")
dataset_toxic = dataset_toxic['train']
dataset_toxic = dataset_toxic.train_test_split(test_size=0.25,seed=42,)

test_valid  = dataset_toxic['test'].train_test_split(test_size=0.5)

dataset_toxic = DatasetDict({
    'train': dataset_toxic['train'].select(range(1000)),
    'valid': test_valid['train'].select(range(100)),
    'test': test_valid['test'].select(range(100))})

dataset_toxic

DatasetDict({
    train: Dataset({
        features: ['id', 'comment_text', 'toxic', 'severe_toxic', 'obscene', 'threat', 'insult', 'identity_hate'],
        num_rows: 1000
    })
    valid: Dataset({
        features: ['id', 'comment_text', 'toxic', 'severe_toxic', 'obscene', 'threat', 'insult', 'identity_hate'],
        num_rows: 100
    })
    test: Dataset({
        features: ['id', 'comment_text', 'toxic', 'severe_toxic', 'obscene', 'threat', 'insult', 'identity_hate'],
        num_rows: 100
    })
})

In [19]:
df = pd.DataFrame(dataset_toxic['train'])
df.head()

Unnamed: 0,id,comment_text,toxic,severe_toxic,obscene,threat,insult,identity_hate
0,5ccc42b72dd3f5db,Hello b.i.t.c.h \nHello little s.l.u.t. Do you...,1,1,1,0,1,1
1,aa6c40c32f92d39d,== Bad faith deletion of new article on Shefa ...,0,0,0,0,0,0
2,8d113a1596e1d4d3,If you don't know what you're doing ... \n\nKe...,0,0,0,0,0,0
3,b93cd46037e2c08f,"""\n\nYou gave it away?! Why in the Sam Hill wo...",0,0,0,0,0,0
4,9a135e8160fff88c,"""\n\n Popular culture \n\nA few months ago I p...",0,0,0,0,0,0


In [20]:
hugging_face_model_id = "google/gemma-3-1b-it" # google/gemma-3-4b-it

from transformers import AutoTokenizer 
tokenizer = AutoTokenizer.from_pretrained(hugging_face_model_id,
                                          padding_side='right',
                                          device_map=device_map,
                                          add_bos=True)

if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

class2id = {'toxic':0,'severe_toxic':1,'obscene':2,'threat':3,'insult':4,'identity_hate':5}
id2class = {v: k for k, v in class2id.items()}

In [21]:
def preprocess_function(sample):
    labels = []
    for class_ in class2id.keys():
        labels.append(sample[class_])

    sample = tokenizer(sample['comment_text'], truncation=False)
    sample['labels'] = labels
    return sample


dataset_toxic_tokenized = dataset_toxic.map(preprocess_function)
dataset_toxic_tokenized = dataset_toxic_tokenized.select_columns(['input_ids','attention_mask','labels'])
dataset_toxic_tokenized

Map:   0%|          | 0/1000 [00:00<?, ? examples/s]

Map: 100%|██████████| 1000/1000 [00:00<00:00, 2616.16 examples/s]
Map: 100%|██████████| 100/100 [00:00<00:00, 2231.92 examples/s]
Map: 100%|██████████| 100/100 [00:00<00:00, 2095.09 examples/s]


DatasetDict({
    train: Dataset({
        features: ['input_ids', 'attention_mask', 'labels'],
        num_rows: 1000
    })
    valid: Dataset({
        features: ['input_ids', 'attention_mask', 'labels'],
        num_rows: 100
    })
    test: Dataset({
        features: ['input_ids', 'attention_mask', 'labels'],
        num_rows: 100
    })
})

In [22]:
sample_index = 3 # Choose any sample index

sample_input_ids = dataset_toxic_tokenized['train']['input_ids'][sample_index]
sample_labels = dataset_toxic_tokenized['train']['labels'][sample_index]

print('Input data for model:')
print(f"IDs   : {sample_input_ids}")
print(f"Labels: {sample_labels}\n")

print('Input data decoded:')
print(f"Tokens: {tokenizer.decode(sample_input_ids)}")
# Reconstruct the label dictionary for this sample
decoded_labels = {id2class[i]: sample_labels[i] for i in range(len(sample_labels))}
print(f"Label dictionary: {decoded_labels}")

Input data for model:
IDs   : [2, 236775, 108, 3048, 5877, 625, 3121, 26052, 8922, 528, 506, 6687, 10892, 1093, 611, 776, 600, 236881, 1452, 1537, 236789, 236745, 12141, 1003, 625, 236761, 38403, 236764, 141814, 3588, 855, 506, 76001, 49165, 564, 17231, 44395, 236764, 532, 564, 2752, 2506, 8379, 2802, 528, 1041, 14064, 236761, 1593, 625, 236789, 236751, 45320, 1041, 12866, 600, 692, 2752, 2506, 531, 1441, 1546, 1032, 568, 1452, 611, 236789, 560, 3931, 8672, 2311, 78345, 236881, 568, 3524, 735, 611, 1010, 528, 625, 672, 3697, 990, 236881, 5315, 990, 611, 236789, 560, 6950, 625, 2907, 2900, 573, 236881, 138, 240478, 201028, 138, 236775]
Labels: [0, 0, 0, 0, 0, 0]

Input data decoded:
Tokens: <bos>"

You gave it away?! Why in the Sam Hill would you do that? And don't worry about it. Actually, Brawl came out the VERY DAY I MOVED, and I never got internet access in my apartment. So it's kinda my fault that we never got to play each other ( And you've started college too huh? (Or have you be

In [23]:
from transformers import DataCollatorWithPadding
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

In [24]:
sample_batch_ids           = dataset_toxic_tokenized['train']['input_ids'][0:3]
sample_batch_ids_collator  = data_collator(dataset_toxic_tokenized['train'][:3])['input_ids'][0:3]
print([len(x) for x in sample_batch_ids ])
print([len(x) for x in sample_batch_ids_collator ])

#length of each sample without datacollator : [74, 37, 159]
#length of each sample with datacollator    :[159, 159, 159]

[62, 341, 37]
[341, 341, 341]


In [25]:
import torch
from transformers import AutoTokenizer, Gemma3ForCausalLM

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16)


model = Gemma3ForCausalLM.from_pretrained(hugging_face_model_id, 
                                          torch_dtype=torch.bfloat16, 
                                          device_map=gpu_device,
                                          attn_implementation='eager',
                                          quantization_config=bnb_config  )

model.lm_head = torch.nn.Linear(model.config.hidden_size, len(class2id.keys()), bias=False,device=gpu_device)

`torch_dtype` is deprecated! Use `dtype` instead!


In [26]:
from peft import LoraConfig, PeftModel, prepare_model_for_kbit_training, get_peft_model
model.gradient_checkpointing_enable()
model = prepare_model_for_kbit_training(model)

In [27]:
import bitsandbytes as bnb
def find_all_linear_names(model):
    cls = bnb.nn.Linear4bit #if args.bits == 4 else (bnb.nn.Linear8bitLt if args.bits == 8 else torch.nn.Linear)
    lora_module_names = set()
    for name, module in model.named_modules():
        if isinstance(module, cls):
            names = name.split('.')
            lora_module_names.add(names[0] if len(names) == 1 else names[-1])
        if 'lm_head' in lora_module_names: # needed for 16-bit
            lora_module_names.remove('lm_head')
    return list(lora_module_names)

modules = find_all_linear_names(model)
modules = ['gate_proj', 'down_proj', 'v_proj', 'k_proj', 'q_proj', 'o_proj', 'up_proj']

In [28]:
from peft import LoraConfig, get_peft_model

lora_config = LoraConfig(
    r=64,
    lora_alpha=32,
    target_modules=modules,
    lora_dropout=0.1,
    bias="none",
    task_type="SEQ_CLS")

model = get_peft_model(model, lora_config)

In [29]:
model.print_trainable_parameters()
#trainable params: 119,209,984 || all params: 3,999,488,512 || trainable%: 2.9806

trainable params: 52,183,040 || all params: 1,052,075,904 || trainable%: 4.9600


In [30]:
class Gemma3ForSequenceClassification(PeftModelForSequenceClassification):
    def __init__(self, peft_config: PeftConfig, model: AutoModelForCausalLM, adapter_name="default"):
        super().__init__(model, peft_config, adapter_name)
        self.num_labels = model.config.num_labels
        self.problem_type = "multi_label_classification" # Assuming multi-label

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        **kwargs):
        
        return_dict = return_dict if return_dict is not None else self.config.return_dict

        outputs = self.base_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            **kwargs)

        # Extract logits from the outputs
        logits = outputs.logits

        # select last "real" token and ignore padding tokens

        sequence_lengths   = torch.sum(attention_mask, dim=1)
        last_token_indices = sequence_lengths - 1
        batch_size         = logits.shape[0]
       
        # Get the logits for the last token in the sequence
        logits = logits[torch.arange(batch_size, device=logits.device), last_token_indices, :]
        #logits = logits[:, -1, :] # if batch_size = 1

        loss = None
        if labels is not None:
            if self.problem_type == "regression":
                loss_fct = torch.nn.MSELoss()
                loss = loss_fct(logits.squeeze(), labels.squeeze())
            elif self.problem_type == "single_label_classification":
                loss_fct = torch.nn.CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            elif self.problem_type == "multi_label_classification":
                loss_fct = torch.nn.BCEWithLogitsLoss()
                loss = loss_fct(logits, labels.float())

        if not return_dict:
            output = (logits,) + outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions)

In [31]:
peft_config = PeftConfig(peft_type="LORA", task_type="SEQ_CLS", inference_mode=False)
for key, value in lora_config.__dict__.items():
    setattr(peft_config, key, value)

wrapped_model = Gemma3ForSequenceClassification(peft_config, model)
wrapped_model.num_labels = len(class2id.keys())

In [32]:
def custom_binary_crossentropy_loss(logits, labels,epsilon=1e-7):
  
    probs = torch.sigmoid(logits)
    probs = torch.clamp(probs, min=epsilon, max=1-epsilon) # capping values
    loss  = -(labels * torch.log(probs) + (1 - labels) * torch.log(1 - probs))
    return torch.mean(loss)

In [33]:
class CustomTrainer(Trainer):     
    def compute_loss(self, model, inputs,num_items_in_batch=4, return_outputs=False): 
        labels  = inputs.get("labels")
        inputs  = inputs.to(gpu_device)
        outputs = model(**inputs)
        logits  = outputs.logits 
        
        loss    = custom_binary_crossentropy_loss(logits, labels)

        return (loss, outputs) if return_outputs else loss

In [34]:
import evaluate

clf_metrics = evaluate.combine(["accuracy", "f1", "precision", "recall"])

def sigmoid(x):
    return 1/(1 + np.exp(-x))

def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    predictions = sigmoid(predictions)
    predictions = (predictions > 0.5).astype(int).reshape(-1)
    return clf_metrics.compute(predictions=predictions, references=labels.astype(int).reshape(-1))

Downloading builder script: 6.79kB [00:00, 19.7MB/s]
Downloading builder script: 7.56kB [00:00, 20.2MB/s]
Downloading builder script: 7.38kB [00:00, 19.2MB/s]


In [35]:
from transformers import EarlyStoppingCallback, TrainingArguments, Trainer
import os

# Define early stopping and checkpoint directory
early_stop = EarlyStoppingCallback(early_stopping_patience=3, # Increased patience slightly
                                   early_stopping_threshold=0.001) # A small threshold
checkpoints_dir = 'my_classification_gemma_model' # More descriptive name

# Set tokenizer parallelism environment variable
os.environ["TOKENIZERS_PARALLELISM"] = "false"

In [36]:
training_args = TrainingArguments(
    gradient_checkpointing=False,  # Gradient Checkpointing ist nicht aktiviert
    gradient_checkpointing_kwargs={"use_reentrant": False},
    logging_strategy="steps",
    logging_steps=100,
    #label_names=classes,
    dataloader_num_workers=4,
    output_dir= checkpoints_dir ,  # Output directory for checkpoints
    learning_rate=5e-5,  # Learning rate for the optimizer
    per_device_train_batch_size=8,  # Batch size per device
    per_device_eval_batch_size=8,  # Batch size per device for evaluation 
    num_train_epochs=3,  # Number of training epochs
    weight_decay=0.01,  # Weight decay for regularization
    eval_strategy='epoch',  # Evaluate after each epoch
    #eval_steps=100,
    save_strategy="epoch",  # Save model checkpoints after each epoch
    load_best_model_at_end=True,  # Load the best model based on the chosen metric
    push_to_hub=False,  # Disable pushing the model to the Hugging Face Hub 
    report_to="tensorboard",  # Disable logging to Weight&Bias
    logging_dir =  f"tensorboard_my_model",
    gradient_accumulation_steps=4,
    fp16=True,
    warmup_ratio =0.05, 
    metric_for_best_model='eval_loss',)  # Metric for selecting the best model 

In [37]:
trainer = Trainer (
    model=wrapped_model,  # The LoRA-adapted model
    args=training_args,  # Training arguments
    train_dataset=dataset_toxic_tokenized['train'],  # Training dataset
    eval_dataset=dataset_toxic_tokenized['valid'],  # Evaluation dataset
    #tokenizer=tokenizer,  # Tokenizer for processing text
    data_collator=data_collator,  # Data collator for preparing batches
    compute_metrics=compute_metrics,  # Function to calculate evaluation metrics
    callbacks=[early_stop]  # Optional early stopping callback
)

In [38]:
trainer.train(resume_from_checkpoint=False)

`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.
  return fn(*args, **kwargs)


Epoch,Training Loss,Validation Loss,Accuracy,F1,Precision,Recall
1,No log,0.151279,0.961667,0.206897,0.6,0.125
2,No log,0.123017,0.965,0.275862,0.8,0.166667
3,No log,0.105841,0.965,0.322581,0.714286,0.208333


  return fn(*args, **kwargs)
  return fn(*args, **kwargs)


TrainOutput(global_step=96, training_loss=0.6360425551732382, metrics={'train_runtime': 173.4682, 'train_samples_per_second': 17.294, 'train_steps_per_second': 0.553, 'total_flos': 4003299077713920.0, 'train_loss': 0.6360425551732382, 'epoch': 3.0})

In [39]:
def prediction(input_text):
    inputs          = tokenizer(input_text, return_tensors="pt",).to("cuda:0")
    with torch.no_grad():
        outputs = wrapped_model(**inputs).logits
    y_prob          = np.round(np.array(torch.sigmoid(outputs).tolist()[0]),5)
    y_sorted_labels = [id2class.get(y) for y  in np.argsort(y_prob)[::-1]]
    y_prob_sorted   = np.sort(y_prob)[::-1]
    
    return y_sorted_labels,y_prob_sorted  

In [40]:
df_test = pd.DataFrame(dataset_toxic['test'])

df_test['pred'] = df_test['comment_text'].map(prediction)
df_test['argsort_label']  = df_test['pred'].apply(lambda x : x[0])
df_test['argsort_prob']   = df_test['pred'].apply(lambda x : x[1])
print(df_test.shape)
df_test.head(n=2)

(100, 11)


Unnamed: 0,id,comment_text,toxic,severe_toxic,obscene,threat,insult,identity_hate,pred,argsort_label,argsort_prob
0,65546f3958775c00,No. You are not supposed to edit someone else'...,0,0,0,0,0,0,"([toxic, obscene, insult, severe_toxic, identi...","[toxic, obscene, insult, severe_toxic, identit...","[0.00146, 0.00043, 0.00028, 7e-05, 2e-05, 1e-05]"
1,03521147ef578854,Thumbing is when you put your thumb on his pen...,1,0,0,0,0,0,"([toxic, obscene, insult, severe_toxic, identi...","[toxic, obscene, insult, severe_toxic, identit...","[0.20979, 0.07893, 0.0278, 0.00381, 0.0027, 0...."


In [41]:
output_dir =  f'my_awesome_model'
trainer.model.save_pretrained(output_dir)

In [42]:
import json
import os
# Assuming lora_config is the LoraConfig object used during setup
# Assuming hugging_face_model_id is the string ID like "google/gemma-3-4b-it"
# Assuming output_dir is the path where the model was saved

adapter_config_path = os.path.join(output_dir, "adapter_config.json")

# Check if file exists before proceeding
if os.path.exists(adapter_config_path):
    try:
        # Load the potentially incomplete config
        with open(adapter_config_path, 'r') as f:
            saved_config_dict = json.load(f)

        # Get parameters from the original LoraConfig
        # Use.to_dict() if available, otherwise __dict__
        try:
            # Ensure lora_config is the actual LoraConfig object instance
            lora_config_dict = lora_config.to_dict()
        except AttributeError:
            # Fallback, might include extra internal attributes
            lora_config_dict = lora_config.__dict__
            # Clean up potential internal attributes if using __dict__
            lora_config_dict = {k: v for k, v in lora_config_dict.items() if not k.startswith('_')}


        # *** FIX 1: Define the specific keys to check ***
        # These are common LoRA parameters that might be missing
        lora_keys_to_check = [
            "r",
            "lora_alpha",
            "lora_dropout",
            "target_modules",
            "bias",
            "modules_to_save", # Important if you used it
            "fan_in_fan_out",
            "init_lora_weights",
            # Add any other specific keys from your LoraConfig if needed
        ]

        # Merge missing or None parameters from the original lora_config
        updated = False
        for key in lora_keys_to_check:
            # Check if key is missing in saved config OR if it exists but is None
            if key not in saved_config_dict or saved_config_dict[key] is None:
                # Check if the key exists in the original config and has a value
                if key in lora_config_dict and lora_config_dict[key] is not None:
                    saved_config_dict[key] = lora_config_dict[key]
                    updated = True

        # Ensure essential base fields are present and correct
        # Use getattr for safer access to lora_config attributes
        original_task_type = getattr(lora_config, 'task_type', 'SEQ_CLS')
        if 'task_type' not in saved_config_dict or saved_config_dict['task_type']!= original_task_type:
             saved_config_dict['task_type'] = original_task_type
             updated = True

        original_base_model = getattr(lora_config, 'base_model_name_or_path', hugging_face_model_id)
        if 'base_model_name_or_path' not in saved_config_dict or saved_config_dict['base_model_name_or_path']!= original_base_model:
             saved_config_dict['base_model_name_or_path'] = original_base_model
             updated = True

        if 'peft_type' not in saved_config_dict or saved_config_dict['peft_type']!= "LORA":
             saved_config_dict['peft_type'] = "LORA"
             updated = True

        # *** FIX 2: Convert set to list before saving ***
        if 'target_modules' in saved_config_dict and isinstance(saved_config_dict['target_modules'], set):
            saved_config_dict['target_modules'] = sorted(list(saved_config_dict['target_modules'])) # Convert set to sorted list
            updated = True # Mark as updated if conversion happened

        if 'modules_to_save' in saved_config_dict and isinstance(saved_config_dict['modules_to_save'], set):
             # Also handle modules_to_save if it could be a set
             saved_config_dict['modules_to_save'] = sorted(list(saved_config_dict['modules_to_save']))
             updated = True


        # Overwrite the config file only if changes were made
        if updated:
            with open(adapter_config_path, 'w') as f:
                # Save the corrected dictionary as JSON
                json.dump(saved_config_dict, f, indent=2)
            print(f"Manually updated adapter configuration: {adapter_config_path}")
            print("New content:", saved_config_dict) # Optional: print the final dict
        else:
            print(f"Adapter configuration already seemed complete or no changes needed: {adapter_config_path}")

    except Exception as e:
        print(f"Error during manual update of adapter_config.json: {e}")
else:
    print(f"Error: adapter_config.json not found at {adapter_config_path}")

Manually updated adapter configuration: my_awesome_model/adapter_config.json
New content: {'auto_mapping': None, 'base_model_name_or_path': 'google/gemma-3-1b-it', 'inference_mode': True, 'peft_type': 'LORA', 'revision': None, 'task_type': 'SEQ_CLS', 'r': 64, 'lora_alpha': 32, 'lora_dropout': 0.1, 'target_modules': ['down_proj', 'gate_proj', 'k_proj', 'o_proj', 'q_proj', 'up_proj', 'v_proj'], 'bias': 'none', 'modules_to_save': ['classifier', 'score', 'classifier', 'score'], 'fan_in_fan_out': False, 'init_lora_weights': True}


In [43]:
output_dir =  'my_awesome_model'
hugging_face_model_id = "google/gemma-3-1b-it" # gemma-3-4b-it
gpu_device = 'cuda:0'

from transformers import AutoTokenizer 
import torch
from transformers import AutoTokenizer, Gemma3ForCausalLM, BitsAndBytesConfig

tokenizer = AutoTokenizer.from_pretrained(hugging_face_model_id,
                                          padding_side='right',
                                          device_map=gpu_device,
                                          add_bos=True)

if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

class2id = {'toxic':0,'severe_toxic':1,'obscene':2,'threat':3,'insult':4,'identity_hate':5}
id2class = {v: k for k, v in class2id.items()}




bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16)

base_model = Gemma3ForCausalLM.from_pretrained(hugging_face_model_id, 
                                          torch_dtype=torch.bfloat16, 
                                          device_map=gpu_device,
                                          attn_implementation='eager',
                                          quantization_config=bnb_config  )

In [44]:
from peft import LoraConfig, get_peft_model

modules = ['gate_proj', 'down_proj', 'v_proj', 'k_proj', 'q_proj', 'o_proj', 'up_proj']

lora_config = LoraConfig(
    r=64,
    lora_alpha=32,
    target_modules=modules,
    lora_dropout=0.1,
    bias="none",
    task_type="SEQ_CLS",modules_to_save=['lm_head'])

peft_config = PeftConfig(peft_type="LORA", task_type="SEQ_CLS", inference_mode=False)
for key, value in lora_config.__dict__.items():
    setattr(peft_config, key, value)


num_labels = len(id2class.keys()) # Must match the number of classes used during training
load_dtype = torch.bfloat16 # Match training or desired inference precision
print(f"Replacing lm_head for {num_labels} classes.")
base_model.lm_head = torch.nn.Linear(
    base_model.config.hidden_size,
    num_labels,
    bias=False,
    device=base_model.device # Ensure head is on the correct device
).to(dtype=load_dtype) # Ensure head matches model dtype


base_model = Gemma3ForSequenceClassification(peft_config, base_model)

Replacing lm_head for 6 classes.


In [45]:
model = PeftModel.from_pretrained(
    base_model,
    output_dir,
    is_trainable=False # Set to False for inference
)

model.eval()

PeftModelForSequenceClassification(
  (base_model): LoraModel(
    (model): Gemma3ForSequenceClassification(
      (base_model): LoraModel(
        (model): Gemma3ForCausalLM(
          (model): Gemma3TextModel(
            (embed_tokens): Gemma3TextScaledWordEmbedding(262144, 1152, padding_idx=0)
            (layers): ModuleList(
              (0-25): 26 x Gemma3DecoderLayer(
                (self_attn): Gemma3Attention(
                  (q_proj): lora.Linear4bit(
                    (base_layer): Linear4bit(in_features=1152, out_features=1024, bias=False)
                    (lora_dropout): ModuleDict(
                      (default): Dropout(p=0.1, inplace=False)
                    )
                    (lora_A): ModuleDict(
                      (default): Linear(in_features=1152, out_features=64, bias=False)
                    )
                    (lora_B): ModuleDict(
                      (default): Linear(in_features=64, out_features=1024, bias=False)
                    )


In [46]:
def prediction(input_text):
    inputs          = tokenizer(input_text, return_tensors="pt",).to("cuda:0")
    with torch.no_grad():
        outputs = model(**inputs).logits
    y_prob          = np.round(np.array(torch.sigmoid(outputs).tolist()[0]),5)
    y_sorted_labels = [id2class.get(y) for y  in np.argsort(y_prob)[::-1]]
    y_prob_sorted   = np.sort(y_prob)[::-1]
    
    return y_sorted_labels,y_prob_sorted 

In [47]:
example = ''' "Who the fuck are you? 
his fee was an umberella it was a joke made by himself i have sources let me post em up it was on SKY SPORTS NEWS. 
He was joking about the rain in manchester. So how the FUCK is that vandelising" '''
prediction(example)

(['obscene', 'insult', 'toxic', 'severe_toxic', 'threat', 'identity_hate'],
 array([0.96484, 0.96094, 0.76953, 0.40625, 0.2793 , 0.08398]))