In [1]:
import os
os.chdir("../")

In [3]:
from peft import PromptTuningConfig, TaskType
from src.hf_models.get_model import get_distilbert, get_distilgpt2


prefix_config = PromptTuningConfig(
    task_type=TaskType.SEQ_CLS,
    num_virtual_tokens=10,
    prompt_tuning_init_text="predict the class "
    
)

In [6]:
model, tokenizer, config = get_distilgpt2(task="SequenceClassification",
                                          num_labels=77)

tokenizer.pad_token = tokenizer.eos_token

Some weights of GPT2ForSequenceClassification were not initialized from the model checkpoint at /home/tess/work/deep_learning/transformers/models/distilgpt2/model and are newly initialized because the shapes did not match:
- score.weight: found shape torch.Size([1, 768]) in the checkpoint and torch.Size([77, 768]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Model loaded successfully.


In [7]:
config

GPT2Config {
  "_num_labels": 1,
  "activation_function": "gelu_new",
  "architectures": [
    "GPT2LMHeadModel"
  ],
  "attn_pdrop": 0.1,
  "bos_token_id": 50256,
  "embd_pdrop": 0.1,
  "eos_token_id": 50256,
  "id2label": {
    "0": "LABEL_0"
  },
  "initializer_range": 0.02,
  "label2id": {
    "LABEL_0": 0
  },
  "layer_norm_epsilon": 1e-05,
  "model_type": "gpt2",
  "n_ctx": 1024,
  "n_embd": 768,
  "n_head": 12,
  "n_inner": null,
  "n_layer": 6,
  "n_positions": 1024,
  "reorder_and_upcast_attn": false,
  "resid_pdrop": 0.1,
  "scale_attn_by_inverse_layer_idx": false,
  "scale_attn_weights": true,
  "summary_activation": null,
  "summary_first_dropout": 0.1,
  "summary_proj_to_labels": true,
  "summary_type": "cls_index",
  "summary_use_proj": true,
  "task_specific_params": {
    "text-generation": {
      "do_sample": true,
      "max_length": 50
    }
  },
  "transformers_version": "4.47.1",
  "use_cache": true,
  "vocab_size": 50257
}

In [8]:
from peft import get_peft_model

model = get_peft_model(model,peft_config=prefix_config)


In [9]:
from src.hf_dataset.dataset import get_banking_77
ds = get_banking_77()
def encode_batch(batch):
  """Encodes a batch of input data using the model tokenizer."""
  return tokenizer(batch["text"], max_length=80, truncation=True, padding="max_length")

# Encode the input data
ds = ds.map(encode_batch, batched=True)
# The transformers model expects the target class column to be named "labels"
ds = ds.rename_column(original_column_name="label", new_column_name="labels")
# Transform to pytorch tensors and only output the required columns
ds.set_format(type="torch", columns=["input_ids", "attention_mask", "labels"])

ds["train"] = ds["train"].remove_columns(["text"])
ds["test"] = ds["test"].remove_columns(["text"])

In [16]:
from transformers import Trainer, TrainingArguments

training_args = TrainingArguments(
    output_dir="./results",
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    learning_rate=1e-4,
    num_train_epochs=1,
    fp16=False
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=ds["train"],
    
)

In [12]:
trainer.train()

Step,Training Loss
500,5.6661
1000,4.5938
1500,4.4082
2000,4.3993
2500,4.3813
3000,4.3959
3500,4.3706
4000,4.3639
4500,4.3646
5000,4.3661


TrainOutput(global_step=10003, training_loss=4.438824054420942, metrics={'train_runtime': 277.731, 'train_samples_per_second': 36.017, 'train_steps_per_second': 36.017, 'total_flos': 204767203737600.0, 'train_loss': 4.438824054420942, 'epoch': 1.0})

In [17]:
trainer.evaluate(ds["test"])

{'eval_loss': 4.380784511566162,
 'eval_model_preparation_time': 0.0012,
 'eval_runtime': 33.8222,
 'eval_samples_per_second': 91.064,
 'eval_steps_per_second': 91.064}

In [18]:
prompt= "I want to check my balance"
input_ids = tokenizer(prompt, return_tensors="pt")

In [19]:
input_ids = input_ids.to(model.device)

In [20]:
import torch
with torch.no_grad():
    logits = model(**input_ids).logits

In [21]:
predicted_class_id = logits.argmax().item()

In [22]:
predicted_class_id

28