### Prepare Dataset

In [1]:
from datasets import Dataset

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
train_data = [
    {"raw_query": "phone cheap", "refined_query": "affordable smartphones under $300"},
    {
        "raw_query": "red dress party",
        "refined_query": "elegant red evening dresses for parties",
    },
]

In [4]:
eval_data = [
    {
        "raw_query": "laptop fast",
        "refined_query": "high-performance laptops with fast processors",
    },
]

### Create Hugging Face Dataset

In [5]:
train_dataset = Dataset.from_dict(
    {
        "raw_query": [item["raw_query"] for item in train_data],
        "refined_query": [item["refined_query"] for item in train_data],
    }
)

In [6]:
train_dataset

Dataset({
    features: ['raw_query', 'refined_query'],
    num_rows: 2
})

In [7]:
eval_dataset = Dataset.from_dict(
    {
        "raw_query": [item["raw_query"] for item in eval_data],
        "refined_query": [item["refined_query"] for item in eval_data],
    }
)

In [8]:
eval_dataset

Dataset({
    features: ['raw_query', 'refined_query'],
    num_rows: 1
})

### Preprocess The Data

In [9]:
from transformers import AutoTokenizer

In [10]:
max_input_length = 128
max_target_length = 128

In [11]:
tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-base")

In [12]:
def preprocess(examples):
    inputs = ["refine e-commerce query: " + query for query in examples["raw_query"]]
    targets = examples["refined_query"]

    model_inputs = tokenizer(
        inputs, max_length=max_input_length, truncation=True, padding="max_length"
    )

    labels = tokenizer(
        targets, max_length=max_target_length, truncation=True, padding="max_length"
    )

    model_inputs["labels"] = labels["input_ids"]

    return model_inputs

In [13]:
tokenized_train = train_dataset.map(preprocess, batched=True)

Map: 100%|██████████| 2/2 [00:00<00:00, 26.45 examples/s]


In [14]:
tokenized_train

Dataset({
    features: ['raw_query', 'refined_query', 'input_ids', 'attention_mask', 'labels'],
    num_rows: 2
})

In [15]:
tokenized_eval = eval_dataset.map(preprocess, batched=True)

Map: 100%|██████████| 1/1 [00:00<00:00, 106.49 examples/s]


In [16]:
tokenized_eval

Dataset({
    features: ['raw_query', 'refined_query', 'input_ids', 'attention_mask', 'labels'],
    num_rows: 1
})

### Set Up Training Arguments

In [17]:
from transformers import Seq2SeqTrainingArguments

In [19]:
training_args = Seq2SeqTrainingArguments(
    eval_strategy="epoch",
    learning_rate=3e-5,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    weight_decay=0.01,
    save_total_limit=3,
    num_train_epochs=5,
    predict_with_generate=True,
    fp16=True,
    report_to="none"
)

### Initialize Model

In [20]:
from transformers import AutoModelForSeq2SeqLM, DataCollatorForSeq2Seq

In [21]:
model = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-base")

In [22]:
model

T5ForConditionalGeneration(
  (shared): Embedding(32128, 768)
  (encoder): T5Stack(
    (embed_tokens): Embedding(32128, 768)
    (block): ModuleList(
      (0): T5Block(
        (layer): ModuleList(
          (0): T5LayerSelfAttention(
            (SelfAttention): T5Attention(
              (q): Linear(in_features=768, out_features=768, bias=False)
              (k): Linear(in_features=768, out_features=768, bias=False)
              (v): Linear(in_features=768, out_features=768, bias=False)
              (o): Linear(in_features=768, out_features=768, bias=False)
              (relative_attention_bias): Embedding(32, 12)
            )
            (layer_norm): T5LayerNorm()
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (1): T5LayerFF(
            (DenseReluDense): T5DenseGatedActDense(
              (wi_0): Linear(in_features=768, out_features=2048, bias=False)
              (wi_1): Linear(in_features=768, out_features=2048, bias=False)
              (wo):

In [23]:
data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)

In [24]:
data_collator

DataCollatorForSeq2Seq(tokenizer=T5TokenizerFast(name_or_path='google/flan-t5-base', vocab_size=32100, model_max_length=512, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'eos_token': '</s>', 'unk_token': '<unk>', 'pad_token': '<pad>', 'additional_special_tokens': ['<extra_id_0>', '<extra_id_1>', '<extra_id_2>', '<extra_id_3>', '<extra_id_4>', '<extra_id_5>', '<extra_id_6>', '<extra_id_7>', '<extra_id_8>', '<extra_id_9>', '<extra_id_10>', '<extra_id_11>', '<extra_id_12>', '<extra_id_13>', '<extra_id_14>', '<extra_id_15>', '<extra_id_16>', '<extra_id_17>', '<extra_id_18>', '<extra_id_19>', '<extra_id_20>', '<extra_id_21>', '<extra_id_22>', '<extra_id_23>', '<extra_id_24>', '<extra_id_25>', '<extra_id_26>', '<extra_id_27>', '<extra_id_28>', '<extra_id_29>', '<extra_id_30>', '<extra_id_31>', '<extra_id_32>', '<extra_id_33>', '<extra_id_34>', '<extra_id_35>', '<extra_id_36>', '<extra_id_37>', '<extra_id_38>', '<extra_id_39>', '<extra_id_40>', '<extra_id_41>',

### Create Trainer and Start Training

In [25]:
from transformers import Seq2SeqTrainer

In [35]:
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_train,
    eval_dataset=tokenized_eval,
    tokenizer=tokenizer,
    data_collator=data_collator,
)

  trainer = Seq2SeqTrainer(


In [36]:
trainer

<transformers.trainer_seq2seq.Seq2SeqTrainer at 0x7fcb0d59e350>

In [37]:
trainer.train()

Epoch,Training Loss,Validation Loss
1,No log,32.473007
2,No log,31.396206
3,No log,30.618795
4,No log,30.104927
5,No log,29.860668


TrainOutput(global_step=5, training_loss=30.57262268066406, metrics={'train_runtime': 63.1565, 'train_samples_per_second': 0.158, 'train_steps_per_second': 0.079, 'total_flos': 1711893381120.0, 'train_loss': 30.57262268066406, 'epoch': 5.0})

In [38]:
model.save_pretrained("../app/fine_tune_vault/flan-t5-query-refiner-model")

In [39]:
tokenizer.save_pretrained("../app/fine_tune_vault/flan-t5-query-refiner-token")

('../app/fine_tune_vault/flan-t5-query-refiner-token/tokenizer_config.json',
 '../app/fine_tune_vault/flan-t5-query-refiner-token/special_tokens_map.json',
 '../app/fine_tune_vault/flan-t5-query-refiner-token/tokenizer.json')