# Preference Alignment with Direct Preference Optimization (DPO)

This notebook guides you through fine-tuning a language model with Direct Preference Optimization (DPO). We will use the SmolLM2-135M-Instruct model, which has already been through SFT training, so it is compatible with DPO. You can also use the model you trained in [1_instruction_tuning](../../1_instruction_tuning/notebooks/sft_finetuning_example.ipynb).
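
For reference, the objective that DPO minimizes is usually written as below. The symbols are notation introduced here rather than names used in this notebook: $\pi_\theta$ is the policy being trained, $\pi_{\mathrm{ref}}$ is the frozen reference (SFT) model, $x$ is the prompt, $y_w$ and $y_l$ are the chosen and rejected responses, $\sigma$ is the logistic sigmoid, and $\beta$ is the temperature-like coefficient.

$$
\mathcal{L}_{\mathrm{DPO}}(\pi_\theta; \pi_{\mathrm{ref}})
= -\,\mathbb{E}_{(x,\, y_w,\, y_l) \sim \mathcal{D}}
\left[ \log \sigma\!\left(
\beta \log \frac{\pi_\theta(y_w \mid x)}{\pi_{\mathrm{ref}}(y_w \mid x)}
- \beta \log \frac{\pi_\theta(y_l \mid x)}{\pi_{\mathrm{ref}}(y_l \mid x)}
\right) \right]
$$

Each log-ratio is a difference of summed token log-probabilities over the response only, so the loss needs no separate reward model.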

<div style='background-color: lightblue; padding: 10px; border-radius: 5px; margin-bottom: 20px; color:black'>
     <h2 style='margin: 0;color:blue'>Exercise: Aligning SmolLM2 with DPOTrainer</h2>
     <p>Take a dataset from the Hugging Face hub and align a model on it. </p> 
     <p><b>Difficulty Levels</b></p>
     <p>🐢 Use the `trl-lib/ultrafeedback_binarized` dataset</p>
     <p>🐕 Try out the `argilla/ultrafeedback-binarized-preferences` dataset</p>
     <p>🦁 Select a dataset that relates to a real-world use case you’re interested in, or use the model you trained in 
        <a href="../../1_instruction_tuning/notebooks/sft_finetuning_example.ipynb">1_instruction_tuning</a></p>
</div>

In [1]:
# Install the requirements in Google Colab
# !pip install transformers datasets trl huggingface_hub

# Authenticate to Hugging Face

from huggingface_hub import login

login()

# For convenience, you can store your Hub token in an environment variable named HF_TOKEN


## Import libraries


In [2]:
import torch
import os
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset
from trl import DPOTrainer, DPOConfig



## Format dataset

In [3]:
# Load dataset

# TODO: 🦁🐕 change the dataset to one of your choosing
dataset = load_dataset(path="trl-lib/ultrafeedback_binarized", split="train")

In [4]:
dataset[0]

{'chosen': [{'content': 'Use the pygame library to write a version of the classic game Snake, with a unique twist',
   'role': 'user'},
  {'content': "Sure, I'd be happy to help you write a version of the classic game Snake using the pygame library! Here's a basic outline of how we can approach this:\n\n1. First, we'll need to set up the game display and create a game object that we can use to handle the game's state.\n2. Next, we'll create the game's grid, which will be used to represent the game board. We'll need to define the size of the grid and the spaces within it.\n3. After that, we'll create the snake object, which will be used to represent the player's movement. We'll need to define the size of the snake and the speed at which it moves.\n4. We'll also need to create a food object, which will be used to represent the food that the player must collect to score points. We'll need to define the location of the food and the speed at which it moves.\n5. Once we have these objects se
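
The record printed above starts with the `chosen` conversation, whose first message is the user prompt. When a preference dataset stores the prompt only implicitly like this, it can be recovered as the leading messages that `chosen` and `rejected` have in common. The following is a minimal sketch under that assumption; `split_shared_prompt` is a hypothetical helper written for illustration, not a function from `datasets` or `trl`, and the two toy conversations are made up.

```python
def split_shared_prompt(chosen, rejected):
    """Split two conversations into (prompt, chosen_reply, rejected_reply).

    The prompt is taken to be the longest run of leading messages that the
    two conversations share; everything after it is the response.
    """
    n = 0
    for a, b in zip(chosen, rejected):
        if a != b:
            break
        n += 1
    return chosen[:n], chosen[n:], rejected[n:]


# Toy preference pair (illustrative data, not taken from the dataset)
chosen = [
    {"role": "user", "content": "Name a prime number."},
    {"role": "assistant", "content": "7 is prime."},
]
rejected = [
    {"role": "user", "content": "Name a prime number."},
    {"role": "assistant", "content": "9 is prime."},
]

prompt, chosen_reply, rejected_reply = split_shared_prompt(chosen, rejected)
print(prompt)          # the shared user turn
print(chosen_reply)    # only the preferred assistant turn
print(rejected_reply)  # only the dispreferred assistant turn
```

Keeping the prompt separate matters later because DPO scores only the response tokens, so the training code needs to know where the prompt ends.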

In [38]:
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from transformers import AutoTokenizer, AutoModelForCausalLM, get_scheduler
from torch.optim import AdamW
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from datasets import load_dataset
from tqdm import tqdm
import os
import math

# --- Configuration ---
MODEL_ID = "HuggingFaceTB/SmolLM2-135M-Instruct" # A small model for quick demonstration
DATASET_ID = "HuggingFaceH4/ultrafeedback_binarized"
OUTPUT_DIR = "./dpo_custom_tinyllama_ultrafeedback"

# Training parameters
NUM_TRAIN_EXAMPLES = 1000 # Use a small subset for demonstration
NUM_EVAL_EXAMPLES = 200
EPOCHS = 1
LEARNING_RATE = 1e-5 # DPO often uses a lower learning rate than SFT
BETA = 0.1 # DPO beta: scales the log-ratio margin; larger values keep the policy closer to the reference model. 0.1 is a common starting point
BETA1 = 0.3
BATCH_SIZE = 2
GRADIENT_ACCUMULATION_STEPS = 4 # Effective batch size = BATCH_SIZE * GRADIENT_ACCUMULATION_STEPS = 8
MAX_LENGTH = 512 # Max total sequence length (prompt + response)
MAX_PROMPT_LENGTH = 256 # Max prompt length

# Device setup

if torch.cuda.is_available():
    DEVICE = "cuda"
    DTYPE = torch.bfloat16 # bfloat16 is usually best for NVIDIA GPUs (Ampere architecture and newer)
elif torch.backends.mps.is_available():
    DEVICE = "mps"
    DTYPE = torch.float16 # MPS typically supports float16 (half-precision), but not bfloat16.
                         # If float16 causes issues, fall back to torch.float32
    print("Using MPS backend. Note: BFloat16 is not supported on MPS, using Float16.")
else:
    DEVICE = "cpu"
    DTYPE = torch.float32 # CPU runs best with float32

print(f"Selected device: {DEVICE} with dtype: {DTYPE}")
# --- 1. Load Models and Tokenizer ---
print(f"Loading model: {MODEL_ID}...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
# Crucial for padding and chat templates
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"

# Policy Model (will be trained with LoRA)
# Use prepare_model_for_kbit_training if you're using quantization (e.g., 4-bit)
print(DTYPE)
policy_model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=DTYPE,
    device_map=DEVICE
)
policy_model.config.use_cache = False # Required for gradient checkpointing, often helpful for training
policy_model.train() # Set to train mode for gradients

# Apply LoRA to the policy model
peft_config = LoraConfig(
    r=16,
    lora_alpha=16,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules="all-linear", # or specify: ["q_proj", "v_proj", "k_proj", "o_proj"]
)
policy_model = get_peft_model(policy_model, peft_config)
policy_model.print_trainable_parameters()

# Reference Model (frozen copy of the initial SFT model)
# Ensure this model is *not* trained and is in eval mode.
ref_model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=DTYPE,
    device_map=DEVICE
)

ref_model.eval() # Set to eval mode for no gradients and no dropout
for param in ref_model.parameters():
    param.requires_grad = False
print("Policy model and reference model loaded.")

# --- 2. Data Preparation ---
print(f"Loading dataset: {DATASET_ID}...")

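# `train_sft` is used here; the dataset also provides a `train_prefs` split intended for preference tuning
dataset = load_dataset(path=DATASET_ID, split="train_sft")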

# For demonstration, select a small subset
if NUM_TRAIN_EXAMPLES:
    dataset = dataset.shuffle(seed=42)
    train_dataset_raw = dataset.select(range(NUM_TRAIN_EXAMPLES))
    eval_dataset_raw = dataset.select(range(NUM_TRAIN_EXAMPLES, NUM_TRAIN_EXAMPLES + NUM_EVAL_EXAMPLES))
else:
    train_dataset_raw = dataset.train_test_split(test_size=0.1, seed=42)['train']
    eval_dataset_raw = dataset.train_test_split(test_size=0.1, seed=42)['test']

print(f"Loaded {len(train_dataset_raw)} training examples and {len(eval_dataset_raw)} evaluation examples.")


Using MPS backend. Note: BFloat16 is not supported on MPS, using Float16.
Selected device: mps with dtype: torch.float16
Loading model: HuggingFaceTB/SmolLM2-135M-Instruct...
torch.float16
trainable params: 4,884,480 || all params: 139,399,488 || trainable%: 3.5039
Policy model and reference model loaded.
Loading dataset: HuggingFaceH4/ultrafeedback_binarized...
Loaded 1000 training examples and 200 evaluation examples.
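
The `trainable params` figure in the output above can be reproduced by hand. LoRA with rank `r` adds two small matrices next to each adapted linear layer, which costs `r * (in_features + out_features)` parameters per layer. The sketch below assumes the SmolLM2-135M layer sizes listed in the constants (they are not printed anywhere in this notebook) and assumes that `target_modules="all-linear"` adapts the seven linear layers of every decoder block but not the output head.

```python
# Assumed SmolLM2-135M dimensions (not shown in this notebook)
HIDDEN = 576
KV_DIM = 192          # 3 key/value heads x 64 dimensions per head
INTERMEDIATE = 1536
NUM_LAYERS = 30
RANK = 16             # `r` in the LoraConfig above

# (in_features, out_features) of each linear layer in one decoder block
LINEAR_SHAPES = {
    "q_proj": (HIDDEN, HIDDEN),
    "k_proj": (HIDDEN, KV_DIM),
    "v_proj": (HIDDEN, KV_DIM),
    "o_proj": (HIDDEN, HIDDEN),
    "gate_proj": (HIDDEN, INTERMEDIATE),
    "up_proj": (HIDDEN, INTERMEDIATE),
    "down_proj": (INTERMEDIATE, HIDDEN),
}


def lora_params(in_features, out_features, rank):
    # LoRA adds A (rank x in_features) and B (out_features x rank)
    return rank * (in_features + out_features)


per_block = sum(lora_params(i, o, RANK) for i, o in LINEAR_SHAPES.values())
trainable = per_block * NUM_LAYERS
total = 139_399_488   # "all params" from the output above

print(f"{trainable:,}")                    # 4,884,480
print(f"{100 * trainable / total:.4f}%")   # 3.5039%
```

Both numbers agree with the logged line, which supports the assumed layer sizes.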


In [20]:
dataset[0]

{'prompt': 'Do you know something about crystallography and structure factor?',
 'prompt_id': '6cc01e0932f2dc27f2e6bb95e9d6c20de1be8c40c1ad17f83f9899a15d3cf195',
 'chosen': [{'content': 'Do you know something about crystallography and structure factor?',
   'role': 'user'},
  {'content': 'Crystallography is the science of the arrangement of atoms in solids. It is a vast and interdisciplinary field that has applications in physics, chemistry, materials science, biology, and engineering.\n\nThe structure factor is a mathematical function that is used to describe the diffraction of waves by a crystal. It is a complex number that is related to the atomic positions in the crystal.\n\nThe structure factor can be used to calculate the intensity of the diffracted waves. This information can be used to determine the atomic positions in the crystal and to study the structure of materials.\n\nCrystallography is a powerful tool for understanding the structure of materials. It has been used to dete
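
The training cell that follows stores a `prompt_len` for every example so that only response tokens contribute to the log-probability sums. Because a causal language model predicts token `i + 1` from position `i`, the mask has to be shifted by one. The toy sketch below shows that index arithmetic in plain Python; `response_mask` is a hypothetical helper, and the lengths are made-up values.

```python
def response_mask(seq_len, prompt_len, padded_len):
    """Mask over next-token positions for one right-padded sequence.

    The shifted sequence has padded_len - 1 positions, and position i scores
    token i + 1. Response tokens sit at original indices
    prompt_len .. seq_len - 1, i.e. shifted indices prompt_len - 1 .. seq_len - 2.
    """
    return [prompt_len - 1 <= i < seq_len - 1 for i in range(padded_len - 1)]


# 3 prompt tokens + 3 response tokens, right-padded to length 8
mask = response_mask(seq_len=6, prompt_len=3, padded_len=8)
print([int(m) for m in mask])   # [0, 0, 1, 1, 1, 0, 0]
print(sum(mask))                # 3 scored positions, one per response token
```

If `prompt_len` is too large, the first response tokens are silently dropped from the sum, so the prompt must be tokenized exactly as it appears at the start of the full conversation.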

In [21]:

def preprocess_function(examples):
    processed = {
        "prompt_input_ids": [],
        "chosen_input_ids": [],
        "rejected_input_ids": [],
        "prompt_attention_mask": [],
        "chosen_attention_mask": [],
        "rejected_attention_mask": [],
        "prompt_len": []
    }

    for i in range(len(examples['prompt'])):
        current_prompt_messages = examples['prompt'][i]
        current_chosen_messages = examples['chosen'][i]
        current_rejected_messages = examples['rejected'][i]

        # Normalize every field to a list of {"role", "content"} messages.
        # In this dataset `prompt` is a plain string (see the sample record above),
        # while `chosen` and `rejected` are full conversations.
        def ensure_message_list(messages, is_prompt=False, idx=i):
            if isinstance(messages, list) and all(isinstance(m, dict) and 'role' in m and 'content' in m for m in messages):
                return messages
            elif isinstance(messages, str):
                if is_prompt:
                    # A string prompt is expected here: wrap it as a single user turn.
                    return [{"role": "user", "content": messages}]
                else: # Chosen/rejected responses should not be plain strings
                    print(f"Warning: Chosen/Rejected entry {idx} is a string, which is unexpected. Skipping example.")
                    return None
            else:
                # Neither a message list nor a string: unrecognized format.
                print(f"Warning: Malformed entry {idx} (type: {type(messages)}). Skipping example.")
                return None

        current_prompt_messages = ensure_message_list(current_prompt_messages, is_prompt=True)
        current_chosen_messages = ensure_message_list(current_chosen_messages)
        current_rejected_messages = ensure_message_list(current_rejected_messages)

        if current_prompt_messages is None or current_chosen_messages is None or current_rejected_messages is None:
            continue # Skip this example if any part is malformed


        # Format the prompt for DPO. `add_generation_prompt=True` already appends the
        # assistant header, so no empty assistant turn is added; this keeps the prompt
        # a prefix of the chosen/rejected strings and `prompt_len` accurate.
        prompt_str = tokenizer.apply_chat_template(
            current_prompt_messages,
            tokenize=False,
            add_generation_prompt=True
        )
        
        # Format chosen and rejected responses (full conversation)
        chosen_str = tokenizer.apply_chat_template(
            current_chosen_messages,
            tokenize=False
        )
        rejected_str = tokenizer.apply_chat_template(
            current_rejected_messages,
            tokenize=False
        )

        # Tokenize (don't pad here, DataCollator will handle it)
        prompt_encoded = tokenizer(prompt_str, truncation=True, max_length=MAX_PROMPT_LENGTH)
        chosen_encoded = tokenizer(chosen_str, truncation=True, max_length=MAX_LENGTH)
        rejected_encoded = tokenizer(rejected_str, truncation=True, max_length=MAX_LENGTH)

        # Filter out examples that are too long after tokenization
        if (len(prompt_encoded['input_ids']) >= MAX_PROMPT_LENGTH or
            len(chosen_encoded['input_ids']) >= MAX_LENGTH or
            len(rejected_encoded['input_ids']) >= MAX_LENGTH):
            # print(f"Skipping example due to length: Prompt {len(prompt_encoded['input_ids'])}, Chosen {len(chosen_encoded['input_ids'])}, Rejected {len(rejected_encoded['input_ids'])}")
            continue

        processed["prompt_input_ids"].append(prompt_encoded["input_ids"])
        processed["chosen_input_ids"].append(chosen_encoded["input_ids"])
        processed["rejected_input_ids"].append(rejected_encoded["input_ids"])
        processed["prompt_attention_mask"].append(prompt_encoded["attention_mask"])
        processed["chosen_attention_mask"].append(chosen_encoded["attention_mask"])
        processed["rejected_attention_mask"].append(rejected_encoded["attention_mask"])
        processed["prompt_len"].append(len(prompt_encoded["input_ids"]))

    return processed


print("Preprocessing dataset (applying chat template and tokenizing)...")
train_dataset = train_dataset_raw.map(
    preprocess_function,
    batched=True,
    remove_columns=train_dataset_raw.column_names,
    num_proc=os.cpu_count(),
    desc="Preprocessing train dataset"
)
eval_dataset = eval_dataset_raw.map(
    preprocess_function,
    batched=True,
    remove_columns=eval_dataset_raw.column_names,
    num_proc=os.cpu_count(),
    desc="Preprocessing eval dataset"
)

# Convert lists to tensors for DataLoader
train_dataset.set_format(type="torch", columns=['prompt_input_ids', 'chosen_input_ids', 'rejected_input_ids',
                                                'prompt_attention_mask', 'chosen_attention_mask', 'rejected_attention_mask', 'prompt_len'])
eval_dataset.set_format(type="torch", columns=['prompt_input_ids', 'chosen_input_ids', 'rejected_input_ids',
                                               'prompt_attention_mask', 'chosen_attention_mask', 'rejected_attention_mask', 'prompt_len'])


print(f"After preprocessing: {len(train_dataset)} training examples and {len(eval_dataset)} evaluation examples.")

# Custom Data Collator for DPO
class DPODataCollator:
    def __init__(self, tokenizer, max_length=MAX_LENGTH, max_prompt_length=MAX_PROMPT_LENGTH):
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.max_prompt_length = max_prompt_length

    def __call__(self, features):
        batch = {}
        for key in features[0].keys():
            batch[key] = [f[key] for f in features]

        # Pad each sequence type once. With padding='longest', tokenizer.pad returns
        # the padded input_ids together with the matching zero-padded attention_mask.
        for prefix in ("prompt", "chosen", "rejected"):
            padded = self.tokenizer.pad(
                {
                    'input_ids': batch[f'{prefix}_input_ids'],
                    'attention_mask': batch[f'{prefix}_attention_mask'],
                },
                padding='longest',
                return_tensors='pt',
            )
            batch[f'{prefix}_input_ids'] = padded['input_ids']
            batch[f'{prefix}_attention_mask'] = padded['attention_mask']

        # Convert prompt_len to tensor
        batch['prompt_len'] = torch.tensor(batch['prompt_len'], dtype=torch.long)

        return batch

data_collator = DPODataCollator(tokenizer)

train_dataloader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    collate_fn=data_collator
)
eval_dataloader = DataLoader(
    eval_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    collate_fn=data_collator
)

# --- 3. Helper Function to Calculate Log Probabilities ---
def get_log_probs(model, input_ids, attention_mask, prompt_len):
    """
    Calculates the log probability of a sequence of tokens given a model,
    masking out the prompt part and padding.

    Args:
        model: The language model (policy or reference).
        input_ids: Tensor of tokenized sequence (prompt + response).
        attention_mask: Tensor of attention mask for the sequence.
        prompt_len: Tensor of lengths of the prompt for each example in batch.

    Returns:
        A tensor of shape (batch_size,) containing the sum of log probabilities
        for the response tokens only.
    """
    with torch.no_grad() if model.training is False else torch.enable_grad():
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
        logits = outputs.logits # (batch_size, sequence_length, vocab_size)

    # Shift logits and labels for causal LM
    # The loss is computed for token_i given token_0 to token_{i-1}
    # So, logits[:, :-1, :] corresponds to predicting token_1 to token_{length-1}
    # And input_ids[:, 1:] are the actual token_1 to token_{length-1}
    logits = logits[:, :-1, :]
    labels = input_ids[:, 1:]
    
    # Calculate log_softmax over the vocabulary dimension
    log_probs = F.log_softmax(logits, dim=-1) # (batch_size, sequence_length - 1, vocab_size)

    # Gather the log probabilities for the actual next tokens
    # `labels.unsqueeze(-1)` makes it (batch_size, sequence_length - 1, 1)
    # `log_probs.gather(dim=-1, index=...)` picks the log_prob at the label index
    # `squeeze(-1)` removes the last dimension, resulting in (batch_size, sequence_length - 1)
    token_log_probs = torch.gather(log_probs, dim=-1, index=labels.unsqueeze(-1)).squeeze(-1)

    # Create a mask to only consider response tokens
    # tokens corresponding to prompt_len[i] up to the end of the sequence.
    # The mask needs to be shifted by 1 because `token_log_probs` is also shifted.
    sequence_lengths = attention_mask.sum(dim=-1) # Actual length of each sequence before padding
    
    # Create an index tensor for each position in the shifted sequence
    indices = torch.arange(token_log_probs.shape[1], device=token_log_probs.device).unsqueeze(0) # (1, seq_len-1)

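    # Mask for response tokens: True if index >= prompt_len - 1 AND index < sequence_length - 1
    # (both bounds are shifted by 1 to match the shifted labels). With right-padding, the upper
    # bound already excludes pad positions, so no separate pad-token check is applied: pad_token
    # is set to eos_token above, and such a check would also drop each response's final EOS token.
    response_mask = (indices >= (prompt_len - 1).unsqueeze(1)) & \
                    (indices < (sequence_lengths - 1).unsqueeze(1))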

    # Apply the mask
    masked_log_probs = token_log_probs * response_mask.float()
    
    # Sum the log probabilities for each example
    return masked_log_probs.sum(dim=-1) # (batch_size,)

# --- 4. Optimizer and Scheduler ---
optimizer = AdamW(policy_model.parameters(), lr=LEARNING_RATE)

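# Round up: the training loop also steps the optimizer on a final, partial accumulation window
num_training_steps = math.ceil(len(train_dataloader) / GRADIENT_ACCUMULATION_STEPS) * EPOCHS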
lr_scheduler = get_scheduler(
    name="cosine",
    optimizer=optimizer,
    num_warmup_steps=0,
    num_training_steps=num_training_steps,
)

# --- 5. Training Loop ---
print("Starting DPO training loop...")
global_step = 0
policy_model.zero_grad()

for epoch in range(EPOCHS):
    policy_model.train() # Ensure policy model is in train mode
    total_loss = 0
    progress_bar = tqdm(train_dataloader, desc=f"Epoch {epoch+1}/{EPOCHS} Training")

    for step, batch in enumerate(progress_bar):
        # Move batch to device
        batch = {k: v.to(DEVICE) for k, v in batch.items()}

        # Compute log probabilities for chosen responses
        log_prob_chosen_policy = get_log_probs(policy_model, batch['chosen_input_ids'], batch['chosen_attention_mask'], batch['prompt_len'])
        with torch.no_grad(): # Ensure no gradients for reference model
            log_prob_chosen_ref = get_log_probs(ref_model, batch['chosen_input_ids'], batch['chosen_attention_mask'], batch['prompt_len'])

        # Compute log probabilities for rejected responses
        log_prob_rejected_policy = get_log_probs(policy_model, batch['rejected_input_ids'], batch['rejected_attention_mask'], batch['prompt_len'])
        with torch.no_grad():
            log_prob_rejected_ref = get_log_probs(ref_model, batch['rejected_input_ids'], batch['rejected_attention_mask'], batch['prompt_len'])

        # Calculate the DPO loss components
        pi_log_ratio = log_prob_chosen_policy - log_prob_rejected_policy
        ref_log_ratio = log_prob_chosen_ref - log_prob_rejected_ref

        dpo_loss_components = -F.logsigmoid(BETA * (pi_log_ratio - ref_log_ratio))
        
        # Average loss over the batch
        loss = dpo_loss_components.mean()
        
        # Backward pass with gradient accumulation
        loss = loss / GRADIENT_ACCUMULATION_STEPS
        loss.backward()

        total_loss += loss.item() * GRADIENT_ACCUMULATION_STEPS # Scale back up for logging

        if (step + 1) % GRADIENT_ACCUMULATION_STEPS == 0 or (step + 1) == len(train_dataloader):
            # Clip gradients to prevent exploding gradients
            torch.nn.utils.clip_grad_norm_(policy_model.parameters(), 1.0)
            
            optimizer.step()
            lr_scheduler.step()
            policy_model.zero_grad()
            global_step += 1
            
            progress_bar.set_postfix({
                "loss": total_loss / (step + 1),
                "learning_rate": lr_scheduler.get_last_lr()[0],
                "global_step": global_step
            })

    print(f"Epoch {epoch+1} finished. Average Training Loss: {total_loss / len(train_dataloader)}")

    # --- Evaluation ---
    policy_model.eval()
    eval_loss = 0
    eval_progress_bar = tqdm(eval_dataloader, desc=f"Epoch {epoch+1}/{EPOCHS} Evaluation")
    with torch.no_grad():
        for batch in eval_progress_bar:
            batch = {k: v.to(DEVICE) for k, v in batch.items()}

            log_prob_chosen_policy = get_log_probs(policy_model, batch['chosen_input_ids'], batch['chosen_attention_mask'], batch['prompt_len'])
            log_prob_chosen_ref = get_log_probs(ref_model, batch['chosen_input_ids'], batch['chosen_attention_mask'], batch['prompt_len'])

            log_prob_rejected_policy = get_log_probs(policy_model, batch['rejected_input_ids'], batch['rejected_attention_mask'], batch['prompt_len'])
            log_prob_rejected_ref = get_log_probs(ref_model, batch['rejected_input_ids'], batch['rejected_attention_mask'], batch['prompt_len'])

            pi_log_ratio = log_prob_chosen_policy - log_prob_rejected_policy
            ref_log_ratio = log_prob_chosen_ref - log_prob_rejected_ref

            dpo_loss_components = -F.logsigmoid(BETA * (pi_log_ratio - ref_log_ratio))
            eval_loss += dpo_loss_components.mean().item()
            eval_progress_bar.set_postfix({"eval_loss": eval_loss / (eval_progress_bar.n + 1)})
            
    print(f"Epoch {epoch+1} finished. Average Evaluation Loss: {eval_loss / len(eval_dataloader)}")

# --- Save the fine-tuned model ---
# Save the LoRA adapter
final_model_path = os.path.join(OUTPUT_DIR, "final_checkpoint")
policy_model.save_pretrained(final_model_path)
tokenizer.save_pretrained(final_model_path)
print(f"Model saved to {final_model_path}")

print("DPO training complete!")

# --- Optional: Test the trained model ---
if DEVICE == "cuda":
    print("\n--- Testing the trained model ---")
    from transformers import pipeline
    from peft import PeftModel

    # Load the base model
    test_model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        torch_dtype=DTYPE,
        device_map=DEVICE
    )
    # Load the LoRA adapter and merge
    test_model = PeftModel.from_pretrained(test_model, final_model_path)
    test_model = test_model.merge_and_unload() # Merge LoRA weights into base model
    test_model.eval()

    pipe = pipeline(
        "text-generation",
        model=test_model,
        tokenizer=tokenizer,
        torch_dtype=DTYPE,
        device=0
    )

    test_prompt_message = [{"role": "user", "content": "Write a short, heartwarming story about an old cat."}]
    test_prompt = tokenizer.apply_chat_template(
        test_prompt_message,
        tokenize=False,
        add_generation_prompt=True
    )

    print(f"Generating response for prompt:\n{test_prompt}")
    
    outputs = pipe(
        test_prompt,
        max_new_tokens=100,
        do_sample=True,
        temperature=0.7,
        top_k=50,
        top_p=0.95,
        repetition_penalty=1.1,
        eos_token_id=tokenizer.eos_token_id
    )
    print("\nGenerated Response (full):")
    print(outputs[0]['generated_text'])
    
    generated_text_only = outputs[0]['generated_text'].replace(test_prompt, '').strip()
    print("\nGenerated Response (clean):")
    print(generated_text_only)
else:
    print("\nSkipping model testing: CUDA not available. Run on GPU for testing.")

Preprocessing dataset (applying chat template and tokenizing)...
After preprocessing: 449 training examples and 108 evaluation examples.
Starting DPO training loop...


Epoch 1/1 Training:   0%|          | 0/225 [00:00<?, ?it/s]You're using a GPT2TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
Epoch 1/1 Training: 100%|██████████| 225/225 [30:44<00:00,  8.20s/it, loss=0.692, learning_rate=7.87e-9, global_step=57]


Epoch 1 finished. Average Training Loss: 0.6923264389567905


Epoch 1/1 Evaluation: 100%|██████████| 54/54 [06:23<00:00,  7.11s/it, eval_loss=0.693]


Epoch 1 finished. Average Evaluation Loss: 0.6926862034532759
Model saved to ./dpo_custom_tinyllama_ultrafeedback/final_checkpoint
DPO training complete!

Skipping model testing: CUDA not available. Run on GPU for testing.
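
Three numbers in the log above can be checked by hand, which helps when judging whether a run like this one learned anything. The sketch below is plain Python with hypothetical helper names. It assumes the scheduler was configured for `225 // 4 = 56` steps and that the cosine schedule has the half-period form written in `cosine_lr`; both are assumptions about the run that produced this log.

```python
import math


def dpo_pair_loss(policy_chosen, policy_rejected, ref_chosen, ref_rejected, beta=0.1):
    """-log(sigmoid(beta * margin)) for one pair of summed log-probabilities."""
    margin = (policy_chosen - policy_rejected) - (ref_chosen - ref_rejected)
    z = beta * margin
    # Numerically stable softplus(-z), which equals -log(sigmoid(z))
    return max(-z, 0.0) + math.log1p(math.exp(-abs(z)))


def optimizer_steps(num_examples, batch_size, accumulation):
    """Number of batches and of optimizer updates in one epoch."""
    batches = math.ceil(num_examples / batch_size)
    return batches, math.ceil(batches / accumulation)


def cosine_lr(step, total_steps, base_lr):
    """Assumed schedule: half a cosine period from base_lr to 0, no warmup."""
    progress = step / total_steps
    return base_lr * max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))


# 1) A policy identical to the reference has zero margin, so the loss is ln 2.
print(round(dpo_pair_loss(-10.0, -12.0, -10.0, -12.0), 4))   # 0.6931

# 2) 449 examples, batch size 2, accumulation 4.
print(optimizer_steps(449, 2, 4))                            # (225, 57)

# 3) Learning rate after 57 updates on a schedule planned for 56.
print(f"{cosine_lr(57, 56, 1e-5):.2e}")                      # 7.87e-09
```

The logged training and evaluation losses (about 0.692 and 0.693) sit almost exactly at ln 2, which suggests the policy barely moved away from the reference during this short run. The step count matches `global_step=57`, and the final learning rate matches the `7.87e-9` shown in the progress bar under the assumptions above.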


## DPO with ratings

In [40]:

def preprocess_function(examples):
    processed = {
        "prompt_input_ids": [],
        "chosen_input_ids": [],
        "rejected_input_ids": [],
        "prompt_attention_mask": [],
        "chosen_attention_mask": [],
        "rejected_attention_mask": [],
        "prompt_len": [],
        "score_chosen": [],
        "score_rejected": []
    }

    for i in range(len(examples['prompt'])):
        current_prompt_messages = examples['prompt'][i]
        current_chosen_messages = examples['chosen'][i]
        current_rejected_messages = examples['rejected'][i]
        current_chosen_score = examples['score_chosen'][i]
        current_rejected_score = examples['score_rejected'][i]

        def ensure_message_list(messages, is_prompt=False, idx=i):
            if isinstance(messages, list) and all(isinstance(m, dict) and 'role' in m and 'content' in m for m in messages):
                return messages
            elif isinstance(messages, str):
                if is_prompt:
                    # `prompt` is a plain string in this dataset: wrap it as a single user turn.
                    return [{"role": "user", "content": messages}]
                else: # Chosen/rejected responses should not be simple strings
                    print(f"Warning: Chosen/Rejected entry {idx} is a string, which is unexpected. Skipping example.")
                    return None
            else:
                # If it's not a list or a string, it's an unrecognized format.
                print(f"Warning: Malformed entry {idx} (type: {type(messages)}). Skipping example.")
                return None
        

        current_prompt_messages = ensure_message_list(current_prompt_messages, is_prompt=True)
        current_chosen_messages = ensure_message_list(current_chosen_messages)
        current_rejected_messages = ensure_message_list(current_rejected_messages)

        if current_prompt_messages is None or current_chosen_messages is None or current_rejected_messages is None:
            continue 


        # Format the prompt for DPO. `add_generation_prompt=True` already appends the
        # assistant header, so no empty assistant turn is added; this keeps the prompt
        # a prefix of the chosen/rejected strings and `prompt_len` accurate.
        prompt_str = tokenizer.apply_chat_template(
            current_prompt_messages,
            tokenize=False,
            add_generation_prompt=True
        )
        
        # Format chosen and rejected responses (full conversation)
        chosen_str = tokenizer.apply_chat_template(
            current_chosen_messages,
            tokenize=False
        )
        rejected_str = tokenizer.apply_chat_template(
            current_rejected_messages,
            tokenize=False
        )

        # Tokenize (don't pad here, DataCollator will handle it)
        prompt_encoded = tokenizer(prompt_str, truncation=True, max_length=MAX_PROMPT_LENGTH)
        chosen_encoded = tokenizer(chosen_str, truncation=True, max_length=MAX_LENGTH)
        rejected_encoded = tokenizer(rejected_str, truncation=True, max_length=MAX_LENGTH)
        
        # Filter out examples that are too long after tokenization
        if (len(prompt_encoded['input_ids']) >= MAX_PROMPT_LENGTH or
            len(chosen_encoded['input_ids']) >= MAX_LENGTH or
            len(rejected_encoded['input_ids']) >= MAX_LENGTH):
            # print(f"Skipping example due to length: Prompt {len(prompt_encoded['input_ids'])}, Chosen {len(chosen_encoded['input_ids'])}, Rejected {len(rejected_encoded['input_ids'])}")
            continue

        processed["prompt_input_ids"].append(prompt_encoded["input_ids"])
        processed["chosen_input_ids"].append(chosen_encoded["input_ids"])
        processed["rejected_input_ids"].append(rejected_encoded["input_ids"])
        processed["prompt_attention_mask"].append(prompt_encoded["attention_mask"])
        processed["chosen_attention_mask"].append(chosen_encoded["attention_mask"])
        processed["rejected_attention_mask"].append(rejected_encoded["attention_mask"])
        processed["prompt_len"].append(len(prompt_encoded["input_ids"]))
        processed["score_chosen"].append(current_chosen_score)
        processed["score_rejected"].append(current_rejected_score)
    return processed


print("Preprocessing dataset (applying chat template and tokenizing)...")
train_dataset = train_dataset_raw.map(
    preprocess_function,
    batched=True,
    remove_columns=train_dataset_raw.column_names,
    num_proc=os.cpu_count(),
    desc="Preprocessing train dataset"
)
eval_dataset = eval_dataset_raw.map(
    preprocess_function,
    batched=True,
    remove_columns=eval_dataset_raw.column_names,
    num_proc=os.cpu_count(),
    desc="Preprocessing eval dataset"
)

# Convert lists to tensors for DataLoader
tensor_columns = ['prompt_input_ids', 'chosen_input_ids', 'rejected_input_ids',
                  'prompt_attention_mask', 'chosen_attention_mask', 'rejected_attention_mask',
                  'prompt_len', 'score_chosen', 'score_rejected']
train_dataset.set_format(type="torch", columns=tensor_columns)
eval_dataset.set_format(type="torch", columns=tensor_columns)
print('Print score rejected')
print(train_dataset['score_rejected'])
print(f"After preprocessing: {len(train_dataset)} training examples and {len(eval_dataset)} evaluation examples.")

# Custom Data Collator for DPO
class DPODataCollator:
    def __init__(self, tokenizer, max_length=MAX_LENGTH, max_prompt_length=MAX_PROMPT_LENGTH):
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.max_prompt_length = max_prompt_length

    def _pad(self, features, prefix):
        # Pad one group of sequences (prompt, chosen, or rejected) to the longest in the batch.
        # tokenizer.pad pads input_ids and attention_mask together, so one call returns both.
        # With padding='longest' the max_length argument has no effect, so it is not passed.
        padded = self.tokenizer.pad(
            {
                'input_ids': [f[f'{prefix}_input_ids'] for f in features],
                'attention_mask': [f[f'{prefix}_attention_mask'] for f in features],
            },
            padding='longest',
            return_tensors='pt',
        )
        return padded['input_ids'], padded['attention_mask']

    def __call__(self, features):
        batch = {}
        for prefix in ('prompt', 'chosen', 'rejected'):
            batch[f'{prefix}_input_ids'], batch[f'{prefix}_attention_mask'] = self._pad(features, prefix)

        # Prompt lengths and scores are one scalar per example
        batch['prompt_len'] = torch.tensor([f['prompt_len'] for f in features], dtype=torch.long)
        score_chosen = [f['score_chosen'] for f in features]
        score_rejected = [f['score_rejected'] for f in features]
        print(score_chosen)    # Debug output: raw chosen scores for this batch
        print(score_rejected)  # Debug output: raw rejected scores for this batch
        batch['score_chosen'] = torch.tensor(score_chosen, dtype=DTYPE)
        batch['score_rejected'] = torch.tensor(score_rejected, dtype=DTYPE)

        return batch

data_collator = DPODataCollator(tokenizer)

train_dataloader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    collate_fn=data_collator
)
eval_dataloader = DataLoader(
    eval_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    collate_fn=data_collator
)

# --- 3. Helper Function to Calculate Log Probabilities ---
def get_log_probs(model, input_ids, attention_mask, prompt_len):
    """
    Calculates the log probability of a sequence of tokens given a model,
    masking out the prompt part and padding.

    Args:
        model: The language model (policy or reference).
        input_ids: Tensor of tokenized sequence (prompt + response).
        attention_mask: Tensor of attention mask for the sequence.
        prompt_len: Tensor of lengths of the prompt for each example in batch.

    Returns:
        A tensor of shape (batch_size,) containing the sum of log probabilities
        for the response tokens only.
    """
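    # Track gradients only when the model is in training mode (the policy during training);
    # models in eval mode (the reference model, or the policy during evaluation) run without them.
    with torch.set_grad_enabled(model.training):
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
        logits = outputs.logits # (batch_size, sequence_length, vocab_size)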

    # Shift logits and labels for causal LM
    # The loss is computed for token_i given token_0 to token_{i-1}
    # So, logits[:, :-1, :] corresponds to predicting token_1 to token_{length-1}
    # And input_ids[:, 1:] are the actual token_1 to token_{length-1}
    logits = logits[:, :-1, :]
    labels = input_ids[:, 1:]
    
    # Calculate log_softmax over the vocabulary dimension
    log_probs = F.log_softmax(logits, dim=-1) # (batch_size, sequence_length - 1, vocab_size)

    # Gather the log probabilities for the actual next tokens
    # `labels.unsqueeze(-1)` makes it (batch_size, sequence_length - 1, 1)
    # `log_probs.gather(dim=-1, index=...)` picks the log_prob at the label index
    # `squeeze(-1)` removes the last dimension, resulting in (batch_size, sequence_length - 1)
    token_log_probs = torch.gather(log_probs, dim=-1, index=labels.unsqueeze(-1)).squeeze(-1)

    # Create a mask to only consider response tokens
    # tokens corresponding to prompt_len[i] up to the end of the sequence.
    # The mask needs to be shifted by 1 because `token_log_probs` is also shifted.
    sequence_lengths = attention_mask.sum(dim=-1) # Actual length of each sequence before padding
    
    # Create an index tensor for each position in the shifted sequence
    indices = torch.arange(token_log_probs.shape[1], device=token_log_probs.device).unsqueeze(0) # (1, seq_len-1)

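    # Mask for response tokens: True if index >= prompt_len (shifted by 1) AND index < sequence_length (shifted by 1).
    # Padding is already excluded by the sequence_length bound (this assumes right padding). Comparing labels
    # against tokenizer.pad_token_id as well would also drop real end-of-turn tokens whenever the pad token
    # is the same as the EOS token.
    response_mask = (indices >= (prompt_len - 1).unsqueeze(1)) & \
                    (indices < (sequence_lengths - 1).unsqueeze(1))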

    # Apply the mask
    masked_log_probs = token_log_probs * response_mask.float()
    
    # Sum the log probabilities for each example
    return masked_log_probs.sum(dim=-1) # (batch_size,)

# --- 4. Optimizer and Scheduler ---
optimizer = AdamW(policy_model.parameters(), lr=LEARNING_RATE)

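# The training loop also steps the optimizer on the last (possibly partial) accumulation group,
# so round the number of optimizer steps per epoch up instead of down.
num_training_steps = ((len(train_dataloader) + GRADIENT_ACCUMULATION_STEPS - 1) // GRADIENT_ACCUMULATION_STEPS) * EPOCHS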
lr_scheduler = get_scheduler(
    name="cosine",
    optimizer=optimizer,
    num_warmup_steps=0,
    num_training_steps=num_training_steps,
)

# --- 5. Training Loop ---
print("Starting DPO with ratings training loop...")
global_step = 0
policy_model.zero_grad()

for epoch in range(EPOCHS):
    policy_model.train() # Ensure policy model is in train mode
    total_loss = 0
    progress_bar = tqdm(train_dataloader, desc=f"Epoch {epoch+1}/{EPOCHS} Training")

    for step, batch in enumerate(progress_bar):
        # Move batch to device
        batch = {k: v.to(DEVICE) for k, v in batch.items()}

        # Compute log probabilities for chosen responses
        log_prob_chosen_policy = get_log_probs(policy_model, batch['chosen_input_ids'], batch['chosen_attention_mask'], batch['prompt_len'])
        with torch.no_grad(): # Ensure no gradients for reference model
            log_prob_chosen_ref = get_log_probs(ref_model, batch['chosen_input_ids'], batch['chosen_attention_mask'], batch['prompt_len'])

        # Compute log probabilities for rejected responses
        log_prob_rejected_policy = get_log_probs(policy_model, batch['rejected_input_ids'], batch['rejected_attention_mask'], batch['prompt_len'])
        with torch.no_grad():
            log_prob_rejected_ref = get_log_probs(ref_model, batch['rejected_input_ids'], batch['rejected_attention_mask'], batch['prompt_len'])

        # Calculate the DPO loss components
        pi_log_ratio = log_prob_chosen_policy - log_prob_rejected_policy
        ref_log_ratio = log_prob_chosen_ref - log_prob_rejected_ref

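        # Rating margin: the chosen-minus-rejected score gap, scaled by 1/BETA1
        score_margin = (batch['score_chosen'] - batch['score_rejected']) / BETA1
        dpo_loss_components = -F.logsigmoid(BETA * (pi_log_ratio - ref_log_ratio - score_margin))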
        
        # Average loss over the batch
        loss = dpo_loss_components.mean()
        
        # Backward pass with gradient accumulation
        loss = loss / GRADIENT_ACCUMULATION_STEPS
        loss.backward()

        total_loss += loss.item() * GRADIENT_ACCUMULATION_STEPS # Scale back up for logging

        if (step + 1) % GRADIENT_ACCUMULATION_STEPS == 0 or (step + 1) == len(train_dataloader):
            # Clip gradients to prevent exploding gradients
            torch.nn.utils.clip_grad_norm_(policy_model.parameters(), 1.0)
            
            optimizer.step()
            lr_scheduler.step()
            policy_model.zero_grad()
            global_step += 1
            
            progress_bar.set_postfix({
                "loss": total_loss / (step + 1),
                "learning_rate": lr_scheduler.get_last_lr()[0],
                "global_step": global_step
            })

    print(f"Epoch {epoch+1} finished. Average Training Loss: {total_loss / len(train_dataloader)}")

    # --- Evaluation ---
    policy_model.eval()
    eval_loss = 0
    eval_progress_bar = tqdm(eval_dataloader, desc=f"Epoch {epoch+1}/{EPOCHS} Evaluation")
    with torch.no_grad():
        for batch in eval_progress_bar:
            batch = {k: v.to(DEVICE) for k, v in batch.items()}

            log_prob_chosen_policy = get_log_probs(policy_model, batch['chosen_input_ids'], batch['chosen_attention_mask'], batch['prompt_len'])
            log_prob_chosen_ref = get_log_probs(ref_model, batch['chosen_input_ids'], batch['chosen_attention_mask'], batch['prompt_len'])

            log_prob_rejected_policy = get_log_probs(policy_model, batch['rejected_input_ids'], batch['rejected_attention_mask'], batch['prompt_len'])
            log_prob_rejected_ref = get_log_probs(ref_model, batch['rejected_input_ids'], batch['rejected_attention_mask'], batch['prompt_len'])

            pi_log_ratio = log_prob_chosen_policy - log_prob_rejected_policy
            ref_log_ratio = log_prob_chosen_ref - log_prob_rejected_ref

            # Same rating-margin loss as in training
            score_margin = (batch['score_chosen'] - batch['score_rejected']) / BETA1
            dpo_loss_components = -F.logsigmoid(BETA * (pi_log_ratio - ref_log_ratio - score_margin))
            eval_loss += dpo_loss_components.mean().item()
            eval_progress_bar.set_postfix({"eval_loss": eval_loss / (eval_progress_bar.n + 1)})
            
    print(f"Epoch {epoch+1} finished. Average Evaluation Loss: {eval_loss / len(eval_dataloader)}")

# --- Save the fine-tuned model ---
# Save the LoRA adapter
final_model_path = os.path.join(OUTPUT_DIR, "final_checkpoint")
policy_model.save_pretrained(final_model_path)
tokenizer.save_pretrained(final_model_path)
print(f"Model saved to {final_model_path}")

print("DPO with ratings training complete!")

# --- Optional: Test the trained model ---
if DEVICE == "cuda":
    print("\n--- Testing the trained model ---")
    from transformers import pipeline
    from peft import PeftModel

    # Load the base model
    test_model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        torch_dtype=DTYPE,
        device_map=DEVICE
    )
    # Load the LoRA adapter and merge
    test_model = PeftModel.from_pretrained(test_model, final_model_path)
    test_model = test_model.merge_and_unload() # Merge LoRA weights into base model
    test_model.eval()

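    # The model was loaded with device_map=DEVICE and torch_dtype=DTYPE, so the pipeline
    # reuses its placement and dtype; passing device again can conflict with the device map.
    pipe = pipeline(
        "text-generation",
        model=test_model,
        tokenizer=tokenizer
    )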

    test_prompt_message = [{"role": "user", "content": "Write a short, heartwarming story about an old cat."}]
    test_prompt = tokenizer.apply_chat_template(
        test_prompt_message,
        tokenize=False,
        add_generation_prompt=True
    )

    print(f"Generating response for prompt:\n{test_prompt}")
    
    outputs = pipe(
        test_prompt,
        max_new_tokens=100,
        do_sample=True,
        temperature=0.7,
        top_k=50,
        top_p=0.95,
        repetition_penalty=1.1,
        eos_token_id=tokenizer.eos_token_id
    )
    print("\nGenerated Response (full):")
    print(outputs[0]['generated_text'])
    
    generated_text_only = outputs[0]['generated_text'].replace(test_prompt, '').strip()
    print("\nGenerated Response (clean):")
    print(generated_text_only)
else:
    print("\nSkipping model testing: CUDA not available. Run on GPU for testing.")
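
The response mask in `get_log_probs` is easy to get wrong by one position, so here is a minimal sketch of its two index bounds on toy rows. It uses plain Python lists instead of tensors so it runs without a model; `response_mask`, `row_a`, and `row_b` are hypothetical names introduced only for this illustration, and right padding is assumed.

```python
def response_mask(prompt_len, attention_mask_row):
    """Return a 0/1 list over shifted positions that marks response tokens only."""
    seq_len = sum(attention_mask_row)          # number of real (unpadded) tokens
    n_shifted = len(attention_mask_row) - 1    # logits[:, :-1] has one position fewer
    # Shifted position i scores token i + 1, so the first response token (at index
    # prompt_len) sits at shifted position prompt_len - 1, and the last real token
    # (at index seq_len - 1) sits at shifted position seq_len - 2.
    return [int(prompt_len - 1 <= i < seq_len - 1) for i in range(n_shifted)]

# Row of width 6: 3 prompt tokens, 2 response tokens, 1 pad on the right.
row_a = response_mask(prompt_len=3, attention_mask_row=[1, 1, 1, 1, 1, 0])
# Row of width 6: 2 prompt tokens, 4 response tokens, no padding.
row_b = response_mask(prompt_len=2, attention_mask_row=[1, 1, 1, 1, 1, 1])

print(row_a)  # [0, 0, 1, 1, 0]
print(row_b)  # [0, 1, 1, 1, 1]
```

The number of ones in each row equals the number of response tokens, which is a quick sanity check worth running on a real batch as well.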

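The loss in the training loop shifts the usual DPO logit by a rating margin. As a rough sketch of how that margin changes the loss, the function below evaluates the same expression on scalar inputs. The name `rating_margin_dpo_loss` and the values chosen for `beta`, `beta1`, and the log-probabilities are assumptions made for illustration; they are not the notebook's settings.

```python
import math

def rating_margin_dpo_loss(pi_chosen, pi_rejected, ref_chosen, ref_rejected,
                           score_chosen, score_rejected, beta, beta1):
    """-log sigmoid(beta * ((pi_c - pi_r) - (ref_c - ref_r) - (s_c - s_r) / beta1))."""
    pi_log_ratio = pi_chosen - pi_rejected
    ref_log_ratio = ref_chosen - ref_rejected
    margin = (score_chosen - score_rejected) / beta1
    z = beta * (pi_log_ratio - ref_log_ratio - margin)
    # Numerically stable form of -log(sigmoid(z)), i.e. softplus(-z)
    return max(-z, 0.0) + math.log1p(math.exp(-abs(z)))

# Hypothetical values: the policy still equals the reference model, so the
# log-ratio difference is zero and only the rating margin moves the loss.
common = dict(pi_chosen=-12.0, pi_rejected=-15.0, ref_chosen=-12.0, ref_rejected=-15.0,
              beta=0.1, beta1=1.0)
loss_equal_scores = rating_margin_dpo_loss(score_chosen=8.0, score_rejected=8.0, **common)
loss_score_gap = rating_margin_dpo_loss(score_chosen=9.0, score_rejected=7.0, **common)

print(round(loss_equal_scores, 4))  # 0.6931, i.e. log(2)
print(round(loss_score_gap, 4))     # 0.7981
```

With equal scores the expression reduces to the standard DPO loss, which is log(2) when the policy matches the reference. A larger score gap raises the starting loss, so the policy has to separate the chosen and rejected responses by more than the margin before the loss drops below log(2).
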
Preprocessing dataset (applying chat template and tokenizing)...
Print score rejected
tensor([ 7.0000,  8.5000,  7.0000,  6.0000,  3.0000,  8.0000,  7.0000,  7.5000,
         4.0000,  8.5000,  4.0000,  3.0000,  9.0000,  7.0000,  8.5000,  2.0000,
         6.0000,  5.0000,  3.0000,  7.0000,  3.0000,  8.0000,  3.0000,  8.0000,
         7.0000,  8.0000,  6.5000,  7.0000,  7.0000,  6.5000,  7.5000,  8.5000,
         3.0000,  7.0000,  8.0000,  6.0000,  8.0000,  3.0000,  7.0000,  7.5000,
         8.0000,  6.0000,  5.0000,  6.0000,  7.5000,  7.5000,  6.0000,  7.0000,
         8.5000,  7.0000,  7.0000,  5.0000,  4.0000,  8.0000,  7.5000,  8.5000,
         2.0000,  8.0000,  6.0000,  8.0000,  7.0000,  6.0000,  3.0000,  2.0000,
         6.0000,  3.0000,  8.0000,  6.0000,  7.5000,  8.0000,  7.0000,  6.5000,
         8.5000, 10.0000,  7.5000,  2.0000,  7.5000,  6.5000,  6.0000,  6.0000,
         2.0000,  6.0000,  8.5000,  7.0000,  7.5000,  7.5000,  6.0000,  7.0000,
         7.5000,  6.0000,  4.0000,

Epoch 1/1 Training:   0%|          | 0/225 [00:00<?, ?it/s]

[tensor(8.5000), tensor(8.)]
[tensor(6.), tensor(2.)]


Epoch 1/1 Training:   0%|          | 1/225 [00:03<14:16,  3.82s/it]

[tensor(7.), tensor(9.)]
[tensor(6.), tensor(9.)]


Epoch 1/1 Training:   1%|          | 2/225 [00:32<1:08:11, 18.35s/it]

[tensor(8.5000), tensor(9.)]
[tensor(8.), tensor(8.)]


Epoch 1/1 Training:   1%|▏         | 3/225 [00:35<42:48, 11.57s/it]  

[tensor(6.), tensor(8.)]
[tensor(2.), tensor(7.)]


Epoch 1/1 Training:   2%|▏         | 4/225 [00:39<30:24,  8.25s/it, loss=1.12, learning_rate=9.99e-6, global_step=1]

[tensor(8.5000), tensor(8.)]
[tensor(3.), tensor(7.)]


Epoch 1/1 Training:   2%|▏         | 5/225 [00:41<22:42,  6.19s/it, loss=1.12, learning_rate=9.99e-6, global_step=1]

[tensor(10.), tensor(8.5000)]
[tensor(8.5000), tensor(8.)]


Epoch 1/1 Training:   3%|▎         | 6/225 [01:10<51:11, 14.02s/it, loss=1.12, learning_rate=9.99e-6, global_step=1]

[tensor(8.5000), tensor(9.)]
[tensor(6.), tensor(7.)]


Epoch 1/1 Training:   3%|▎         | 7/225 [01:13<37:53, 10.43s/it, loss=1.12, learning_rate=9.99e-6, global_step=1]

[tensor(8.5000), tensor(8.5000)]
[tensor(8.), tensor(7.)]


Epoch 1/1 Training:   4%|▎         | 8/225 [01:31<46:28, 12.85s/it, loss=1.1, learning_rate=9.97e-6, global_step=2] 

[tensor(7.5000), tensor(8.)]
[tensor(7.), tensor(7.)]


Epoch 1/1 Training:   4%|▍         | 9/225 [01:34<35:20,  9.82s/it, loss=1.1, learning_rate=9.97e-6, global_step=2]

[tensor(8.5000), tensor(6.)]
[tensor(8.5000), tensor(4.)]


Epoch 1/1 Training:   4%|▍         | 10/225 [01:37<27:38,  7.71s/it, loss=1.1, learning_rate=9.97e-6, global_step=2]

[tensor(9.), tensor(7.5000)]
[tensor(8.5000), tensor(6.5000)]


Epoch 1/1 Training:   5%|▍         | 11/225 [01:40<21:57,  6.16s/it, loss=1.1, learning_rate=9.97e-6, global_step=2]

[tensor(8.), tensor(8.5000)]
[tensor(6.), tensor(7.)]


Epoch 1/1 Training:   5%|▌         | 12/225 [01:44<18:51,  5.31s/it, loss=1.03, learning_rate=9.93e-6, global_step=3]

[tensor(8.), tensor(8.5000)]
[tensor(7.5000), tensor(6.)]


Epoch 1/1 Training:   6%|▌         | 13/225 [01:46<16:03,  4.55s/it, loss=1.03, learning_rate=9.93e-6, global_step=3]

[tensor(9.), tensor(9.)]
[tensor(3.), tensor(3.)]


Epoch 1/1 Training:   6%|▌         | 14/225 [01:49<14:05,  4.01s/it, loss=1.03, learning_rate=9.93e-6, global_step=3]

[tensor(9.), tensor(8.5000)]
[tensor(4.), tensor(6.)]


Epoch 1/1 Training:   7%|▋         | 15/225 [01:52<12:35,  3.60s/it, loss=1.03, learning_rate=9.93e-6, global_step=3]

[tensor(9.), tensor(8.)]
[tensor(7.), tensor(7.)]


Epoch 1/1 Training:   7%|▋         | 16/225 [01:55<12:12,  3.50s/it, loss=1.12, learning_rate=9.87e-6, global_step=4]

[tensor(8.5000), tensor(9.)]
[tensor(8.), tensor(8.)]


Epoch 1/1 Training:   8%|▊         | 17/225 [01:57<11:03,  3.19s/it, loss=1.12, learning_rate=9.87e-6, global_step=4]

[tensor(8.5000), tensor(9.)]
[tensor(6.5000), tensor(5.)]


Epoch 1/1 Training:   8%|▊         | 18/225 [02:14<24:30,  7.11s/it, loss=1.12, learning_rate=9.87e-6, global_step=4]

[tensor(7.), tensor(8.)]
[tensor(4.), tensor(4.)]


Epoch 1/1 Training:   8%|▊         | 19/225 [02:18<21:19,  6.21s/it, loss=1.12, learning_rate=9.87e-6, global_step=4]

[tensor(8.5000), tensor(7.)]
[tensor(8.), tensor(5.)]


Epoch 1/1 Training:   9%|▉         | 20/225 [02:21<18:24,  5.39s/it, loss=1.12, learning_rate=9.8e-6, global_step=5] 

[tensor(6.5000), tensor(8.)]
[tensor(4.), tensor(6.5000)]


Epoch 1/1 Training:   9%|▉         | 21/225 [02:24<15:37,  4.59s/it, loss=1.12, learning_rate=9.8e-6, global_step=5]

[tensor(9.), tensor(8.)]
[tensor(7.), tensor(7.)]


Epoch 1/1 Training:  10%|▉         | 22/225 [02:27<13:34,  4.01s/it, loss=1.12, learning_rate=9.8e-6, global_step=5]

[tensor(7.), tensor(8.)]
[tensor(6.), tensor(6.)]


Epoch 1/1 Training:  10%|█         | 23/225 [02:29<12:13,  3.63s/it, loss=1.12, learning_rate=9.8e-6, global_step=5]

[tensor(8.), tensor(7.5000)]
[tensor(3.), tensor(3.)]


Epoch 1/1 Training:  11%|█         | 24/225 [02:33<12:07,  3.62s/it, loss=1.14, learning_rate=9.72e-6, global_step=6]

[tensor(9.), tensor(6.)]
[tensor(8.), tensor(6.)]


Epoch 1/1 Training:  11%|█         | 25/225 [02:36<11:21,  3.41s/it, loss=1.14, learning_rate=9.72e-6, global_step=6]

[tensor(8.), tensor(9.)]
[tensor(7.5000), tensor(6.)]


Epoch 1/1 Training:  12%|█▏        | 26/225 [02:39<10:48,  3.26s/it, loss=1.14, learning_rate=9.72e-6, global_step=6]

[tensor(7.5000), tensor(9.)]
[tensor(2.), tensor(7.5000)]


Epoch 1/1 Training:  12%|█▏        | 27/225 [02:42<10:15,  3.11s/it, loss=1.14, learning_rate=9.72e-6, global_step=6]

[tensor(7.), tensor(8.5000)]
[tensor(6.), tensor(1.)]


Epoch 1/1 Training:  12%|█▏        | 28/225 [02:44<09:57,  3.03s/it, loss=1.15, learning_rate=9.62e-6, global_step=7]

[tensor(8.5000), tensor(8.5000)]
[tensor(8.), tensor(2.)]


Epoch 1/1 Training:  13%|█▎        | 29/225 [02:47<09:33,  2.93s/it, loss=1.15, learning_rate=9.62e-6, global_step=7]

[tensor(9.), tensor(8.5000)]
[tensor(3.), tensor(6.)]


Epoch 1/1 Training:  13%|█▎        | 30/225 [02:50<09:10,  2.82s/it, loss=1.15, learning_rate=9.62e-6, global_step=7]

[tensor(9.), tensor(9.)]
[tensor(8.5000), tensor(4.)]


Epoch 1/1 Training:  14%|█▍        | 31/225 [02:52<08:46,  2.72s/it, loss=1.15, learning_rate=9.62e-6, global_step=7]

[tensor(8.5000), tensor(7.)]
[tensor(8.), tensor(4.)]


Epoch 1/1 Training:  14%|█▍        | 32/225 [02:55<09:01,  2.80s/it, loss=1.18, learning_rate=9.5e-6, global_step=8] 

[tensor(7.), tensor(9.)]
[tensor(6.), tensor(7.5000)]


Epoch 1/1 Training:  15%|█▍        | 33/225 [02:58<09:05,  2.84s/it, loss=1.18, learning_rate=9.5e-6, global_step=8]

[tensor(7.5000), tensor(7.5000)]
[tensor(4.), tensor(4.)]


Epoch 1/1 Training:  15%|█▌        | 34/225 [03:01<09:05,  2.86s/it, loss=1.18, learning_rate=9.5e-6, global_step=8]

[tensor(7.), tensor(9.)]
[tensor(4.), tensor(8.)]


Epoch 1/1 Training:  16%|█▌        | 35/225 [03:04<08:51,  2.80s/it, loss=1.18, learning_rate=9.5e-6, global_step=8]

[tensor(8.5000), tensor(9.)]
[tensor(8.), tensor(7.)]


Epoch 1/1 Training:  16%|█▌        | 36/225 [03:07<09:10,  2.91s/it, loss=1.17, learning_rate=9.38e-6, global_step=9]

[tensor(8.5000), tensor(8.)]
[tensor(3.), tensor(8.)]


Epoch 1/1 Training:  16%|█▋        | 37/225 [03:09<08:39,  2.76s/it, loss=1.17, learning_rate=9.38e-6, global_step=9]

[tensor(8.5000), tensor(8.)]
[tensor(7.), tensor(7.)]


Epoch 1/1 Training:  17%|█▋        | 38/225 [03:17<13:01,  4.18s/it, loss=1.17, learning_rate=9.38e-6, global_step=9]

[tensor(9.), tensor(7.5000)]
[tensor(7.5000), tensor(6.)]


Epoch 1/1 Training:  17%|█▋        | 39/225 [03:20<11:53,  3.84s/it, loss=1.17, learning_rate=9.38e-6, global_step=9]

[tensor(8.), tensor(8.)]
[tensor(2.), tensor(7.5000)]


Epoch 1/1 Training:  18%|█▊        | 40/225 [03:23<11:34,  3.75s/it, loss=1.17, learning_rate=9.23e-6, global_step=10]

[tensor(9.), tensor(8.)]
[tensor(4.), tensor(6.5000)]


Epoch 1/1 Training:  18%|█▊        | 41/225 [03:26<10:50,  3.54s/it, loss=1.17, learning_rate=9.23e-6, global_step=10]

[tensor(8.), tensor(9.)]
[tensor(2.), tensor(8.)]


Epoch 1/1 Training:  19%|█▊        | 42/225 [03:44<23:33,  7.72s/it, loss=1.17, learning_rate=9.23e-6, global_step=10]

[tensor(8.), tensor(8.5000)]
[tensor(7.), tensor(3.)]


Epoch 1/1 Training:  19%|█▉        | 43/225 [03:47<19:13,  6.34s/it, loss=1.17, learning_rate=9.23e-6, global_step=10]

[tensor(8.5000), tensor(8.5000)]
[tensor(7.5000), tensor(7.5000)]


Epoch 1/1 Training:  20%|█▉        | 44/225 [03:51<16:40,  5.53s/it, loss=1.18, learning_rate=9.08e-6, global_step=11]

[tensor(7.), tensor(8.)]
[tensor(3.), tensor(7.5000)]


Epoch 1/1 Training:  20%|██        | 45/225 [04:07<26:42,  8.90s/it, loss=1.18, learning_rate=9.08e-6, global_step=11]

[tensor(8.5000), tensor(8.5000)]
[tensor(8.5000), tensor(8.5000)]


Epoch 1/1 Training:  20%|██        | 46/225 [04:10<21:15,  7.12s/it, loss=1.18, learning_rate=9.08e-6, global_step=11]

[tensor(7.), tensor(8.5000)]
[tensor(6.), tensor(7.)]


Epoch 1/1 Training:  21%|██        | 47/225 [04:13<17:20,  5.85s/it, loss=1.18, learning_rate=9.08e-6, global_step=11]

[tensor(8.5000), tensor(8.5000)]
[tensor(7.), tensor(7.5000)]


Epoch 1/1 Training:  21%|██▏       | 48/225 [04:16<14:46,  5.01s/it, loss=1.16, learning_rate=8.91e-6, global_step=12]

[tensor(7.), tensor(9.)]
[tensor(6.), tensor(6.)]


Epoch 1/1 Training:  22%|██▏       | 49/225 [04:19<12:33,  4.28s/it, loss=1.16, learning_rate=8.91e-6, global_step=12]

[tensor(8.), tensor(7.)]
[tensor(7.), tensor(6.5000)]


Epoch 1/1 Training:  22%|██▏       | 50/225 [04:22<11:14,  3.86s/it, loss=1.16, learning_rate=8.91e-6, global_step=12]

[tensor(8.), tensor(7.5000)]
[tensor(6.), tensor(7.)]


Epoch 1/1 Training:  23%|██▎       | 51/225 [04:24<10:03,  3.47s/it, loss=1.16, learning_rate=8.91e-6, global_step=12]

[tensor(8.), tensor(8.5000)]
[tensor(6.), tensor(6.)]


Epoch 1/1 Training:  23%|██▎       | 52/225 [04:28<09:50,  3.42s/it, loss=1.15, learning_rate=8.73e-6, global_step=13]

[tensor(9.), tensor(9.)]
[tensor(8.5000), tensor(9.)]


Epoch 1/1 Training:  24%|██▎       | 53/225 [04:31<09:37,  3.36s/it, loss=1.15, learning_rate=8.73e-6, global_step=13]

[tensor(6.), tensor(7.)]
[tensor(4.), tensor(7.)]


Epoch 1/1 Training:  24%|██▍       | 54/225 [04:33<08:57,  3.14s/it, loss=1.15, learning_rate=8.73e-6, global_step=13]

[tensor(8.), tensor(9.)]
[tensor(7.), tensor(6.)]


Epoch 1/1 Training:  24%|██▍       | 55/225 [04:50<20:01,  7.07s/it, loss=1.15, learning_rate=8.73e-6, global_step=13]

[tensor(8.), tensor(8.5000)]
[tensor(7.5000), tensor(6.)]


Epoch 1/1 Training:  25%|██▍       | 56/225 [04:53<16:44,  5.95s/it, loss=1.13, learning_rate=8.54e-6, global_step=14]

[tensor(7.), tensor(3.)]
[tensor(6.5000), tensor(3.)]


Epoch 1/1 Training:  25%|██▌       | 57/225 [04:56<14:00,  5.00s/it, loss=1.13, learning_rate=8.54e-6, global_step=14]

[tensor(8.), tensor(8.5000)]
[tensor(7.), tensor(7.)]


Epoch 1/1 Training:  26%|██▌       | 58/225 [04:59<12:04,  4.34s/it, loss=1.13, learning_rate=8.54e-6, global_step=14]

[tensor(7.), tensor(9.5000)]
[tensor(1.), tensor(4.)]


Epoch 1/1 Training:  26%|██▌       | 59/225 [05:15<22:26,  8.11s/it, loss=1.13, learning_rate=8.54e-6, global_step=14]

[tensor(8.5000), tensor(8.5000)]
[tensor(7.), tensor(7.)]


Epoch 1/1 Training:  27%|██▋       | 60/225 [05:19<18:23,  6.69s/it, loss=1.14, learning_rate=8.33e-6, global_step=15]

[tensor(8.), tensor(7.5000)]
[tensor(6.), tensor(3.)]


Epoch 1/1 Training:  27%|██▋       | 61/225 [05:21<14:53,  5.45s/it, loss=1.14, learning_rate=8.33e-6, global_step=15]

[tensor(8.5000), tensor(8.)]
[tensor(7.), tensor(5.)]


Epoch 1/1 Training:  28%|██▊       | 62/225 [05:24<12:31,  4.61s/it, loss=1.14, learning_rate=8.33e-6, global_step=15]

[tensor(8.), tensor(8.)]
[tensor(5.), tensor(7.)]


Epoch 1/1 Training:  28%|██▊       | 63/225 [05:41<22:18,  8.27s/it, loss=1.14, learning_rate=8.33e-6, global_step=15]

[tensor(8.5000), tensor(8.5000)]
[tensor(8.5000), tensor(7.)]


Epoch 1/1 Training:  28%|██▊       | 64/225 [05:44<18:15,  6.80s/it, loss=1.13, learning_rate=8.12e-6, global_step=16]

[tensor(8.), tensor(8.5000)]
[tensor(7.), tensor(8.)]


Epoch 1/1 Training:  29%|██▉       | 65/225 [06:03<27:20, 10.25s/it, loss=1.13, learning_rate=8.12e-6, global_step=16]

[tensor(8.), tensor(8.5000)]
[tensor(7.), tensor(7.)]


Epoch 1/1 Training:  29%|██▉       | 66/225 [06:05<21:05,  7.96s/it, loss=1.13, learning_rate=8.12e-6, global_step=16]

[tensor(7.), tensor(6.5000)]
[tensor(2.), tensor(6.)]


Epoch 1/1 Training:  30%|██▉       | 67/225 [06:08<16:41,  6.34s/it, loss=1.13, learning_rate=8.12e-6, global_step=16]

[tensor(8.), tensor(8.5000)]
[tensor(7.), tensor(7.)]


Epoch 1/1 Training:  30%|███       | 68/225 [06:11<14:02,  5.37s/it, loss=1.13, learning_rate=7.89e-6, global_step=17]

[tensor(7.), tensor(8.5000)]
[tensor(4.), tensor(7.5000)]


Epoch 1/1 Training:  31%|███       | 69/225 [06:14<12:06,  4.65s/it, loss=1.13, learning_rate=7.89e-6, global_step=17]

[tensor(8.5000), tensor(6.5000)]
[tensor(3.), tensor(6.)]


Epoch 1/1 Training:  31%|███       | 70/225 [06:16<10:28,  4.06s/it, loss=1.13, learning_rate=7.89e-6, global_step=17]

[tensor(7.), tensor(8.5000)]
[tensor(6.), tensor(7.5000)]


Epoch 1/1 Training:  32%|███▏      | 71/225 [06:19<09:12,  3.59s/it, loss=1.13, learning_rate=7.89e-6, global_step=17]

[tensor(8.), tensor(10.)]
[tensor(7.5000), tensor(2.)]


Epoch 1/1 Training:  32%|███▏      | 72/225 [06:35<19:02,  7.47s/it, loss=1.13, learning_rate=7.66e-6, global_step=18]

[tensor(7.5000), tensor(8.5000)]
[tensor(4.), tensor(6.)]


Epoch 1/1 Training:  32%|███▏      | 73/225 [06:39<15:52,  6.27s/it, loss=1.13, learning_rate=7.66e-6, global_step=18]

[tensor(8.5000), tensor(8.)]
[tensor(6.), tensor(3.)]


Epoch 1/1 Training:  33%|███▎      | 74/225 [06:41<12:55,  5.14s/it, loss=1.13, learning_rate=7.66e-6, global_step=18]

[tensor(8.5000), tensor(7.5000)]
[tensor(8.), tensor(6.)]


Epoch 1/1 Training:  33%|███▎      | 75/225 [06:58<21:25,  8.57s/it, loss=1.13, learning_rate=7.66e-6, global_step=18]

[tensor(8.), tensor(9.)]
[tensor(7.5000), tensor(3.)]


Epoch 1/1 Training:  34%|███▍      | 76/225 [07:01<17:15,  6.95s/it, loss=1.14, learning_rate=7.42e-6, global_step=19]

[tensor(7.5000), tensor(8.)]
[tensor(6.), tensor(6.5000)]


Epoch 1/1 Training:  34%|███▍      | 77/225 [07:04<14:07,  5.72s/it, loss=1.14, learning_rate=7.42e-6, global_step=19]

[tensor(6.5000), tensor(7.)]
[tensor(2.), tensor(5.)]


Epoch 1/1 Training:  35%|███▍      | 78/225 [07:22<23:11,  9.46s/it, loss=1.14, learning_rate=7.42e-6, global_step=19]

[tensor(8.5000), tensor(8.5000)]
[tensor(5.), tensor(8.)]


Epoch 1/1 Training:  35%|███▌      | 79/225 [07:25<18:15,  7.50s/it, loss=1.14, learning_rate=7.42e-6, global_step=19]

[tensor(8.5000), tensor(7.5000)]
[tensor(4.), tensor(3.)]


Epoch 1/1 Training:  36%|███▌      | 80/225 [07:28<14:54,  6.17s/it, loss=1.15, learning_rate=7.17e-6, global_step=20]

[tensor(9.), tensor(8.5000)]
[tensor(7.5000), tensor(7.5000)]


Epoch 1/1 Training:  36%|███▌      | 81/225 [07:32<12:54,  5.38s/it, loss=1.15, learning_rate=7.17e-6, global_step=20]

[tensor(8.5000), tensor(8.)]
[tensor(7.5000), tensor(7.)]


Epoch 1/1 Training:  36%|███▋      | 82/225 [07:35<11:02,  4.64s/it, loss=1.15, learning_rate=7.17e-6, global_step=20]

[tensor(9.), tensor(4.)]
[tensor(8.), tensor(3.)]


Epoch 1/1 Training:  37%|███▋      | 83/225 [07:37<09:38,  4.08s/it, loss=1.15, learning_rate=7.17e-6, global_step=20]

[tensor(10.), tensor(8.)]
[tensor(10.), tensor(7.5000)]


Epoch 1/1 Training:  37%|███▋      | 84/225 [07:41<08:53,  3.79s/it, loss=1.14, learning_rate=6.91e-6, global_step=21]

[tensor(9.), tensor(8.5000)]
[tensor(8.5000), tensor(6.)]


Epoch 1/1 Training:  38%|███▊      | 85/225 [07:58<18:15,  7.83s/it, loss=1.14, learning_rate=6.91e-6, global_step=21]

[tensor(8.5000), tensor(8.5000)]
[tensor(7.5000), tensor(6.5000)]


Epoch 1/1 Training:  38%|███▊      | 86/225 [08:01<14:54,  6.43s/it, loss=1.14, learning_rate=6.91e-6, global_step=21]

[tensor(8.), tensor(9.)]
[tensor(8.), tensor(8.5000)]


Epoch 1/1 Training:  39%|███▊      | 87/225 [08:04<12:17,  5.34s/it, loss=1.14, learning_rate=6.91e-6, global_step=21]

[tensor(8.5000), tensor(8.5000)]
[tensor(7.), tensor(2.)]


Epoch 1/1 Training:  39%|███▉      | 88/225 [08:07<10:37,  4.66s/it, loss=1.13, learning_rate=6.65e-6, global_step=22]

[tensor(8.5000), tensor(8.)]
[tensor(7.5000), tensor(7.5000)]


Epoch 1/1 Training:  40%|███▉      | 89/225 [08:09<09:08,  4.03s/it, loss=1.13, learning_rate=6.65e-6, global_step=22]

[tensor(8.), tensor(8.)]
[tensor(3.), tensor(7.)]


Epoch 1/1 Training:  40%|████      | 90/225 [08:12<08:21,  3.72s/it, loss=1.13, learning_rate=6.65e-6, global_step=22]

[tensor(7.5000), tensor(7.5000)]
[tensor(3.), tensor(6.)]


Epoch 1/1 Training:  40%|████      | 91/225 [08:15<07:34,  3.39s/it, loss=1.13, learning_rate=6.65e-6, global_step=22]

[tensor(9.), tensor(9.)]
[tensor(8.5000), tensor(7.)]


Epoch 1/1 Training:  41%|████      | 92/225 [08:18<07:15,  3.28s/it, loss=1.13, learning_rate=6.38e-6, global_step=23]

[tensor(9.), tensor(4.)]
[tensor(8.5000), tensor(4.)]


Epoch 1/1 Training:  41%|████▏     | 93/225 [08:21<06:47,  3.09s/it, loss=1.13, learning_rate=6.38e-6, global_step=23]

[tensor(8.), tensor(9.)]
[tensor(7.5000), tensor(2.)]


Epoch 1/1 Training:  42%|████▏     | 94/225 [08:24<06:35,  3.02s/it, loss=1.13, learning_rate=6.38e-6, global_step=23]

[tensor(7.5000), tensor(7.5000)]
[tensor(7.), tensor(6.)]


Epoch 1/1 Training:  42%|████▏     | 95/225 [08:26<06:23,  2.95s/it, loss=1.13, learning_rate=6.38e-6, global_step=23]

[tensor(9.), tensor(7.5000)]
[tensor(7.), tensor(6.)]


Epoch 1/1 Training:  43%|████▎     | 96/225 [08:29<06:22,  2.97s/it, loss=1.13, learning_rate=6.11e-6, global_step=24]

[tensor(8.5000), tensor(8.5000)]
[tensor(8.), tensor(7.5000)]


Epoch 1/1 Training:  43%|████▎     | 97/225 [08:32<06:09,  2.89s/it, loss=1.13, learning_rate=6.11e-6, global_step=24]

[tensor(7.5000), tensor(10.)]
[tensor(3.), tensor(6.)]


Epoch 1/1 Training:  44%|████▎     | 98/225 [08:51<16:36,  7.84s/it, loss=1.13, learning_rate=6.11e-6, global_step=24]

[tensor(9.), tensor(6.)]
[tensor(8.5000), tensor(2.)]


Epoch 1/1 Training:  44%|████▍     | 99/225 [08:56<14:15,  6.79s/it, loss=1.13, learning_rate=6.11e-6, global_step=24]

[tensor(8.), tensor(8.5000)]
[tensor(7.), tensor(7.)]


Epoch 1/1 Training:  44%|████▍     | 100/225 [09:00<12:13,  5.87s/it, loss=1.13, learning_rate=5.84e-6, global_step=25]

[tensor(9.), tensor(8.5000)]
[tensor(8.), tensor(5.)]


Epoch 1/1 Training:  45%|████▍     | 101/225 [09:02<10:13,  4.95s/it, loss=1.13, learning_rate=5.84e-6, global_step=25]

[tensor(8.), tensor(8.5000)]
[tensor(7.), tensor(7.)]


Epoch 1/1 Training:  45%|████▌     | 102/225 [09:05<08:53,  4.34s/it, loss=1.13, learning_rate=5.84e-6, global_step=25]

[tensor(7.5000), tensor(8.)]
[tensor(3.), tensor(5.)]


Epoch 1/1 Training:  46%|████▌     | 103/225 [09:23<17:05,  8.40s/it, loss=1.13, learning_rate=5.84e-6, global_step=25]

[tensor(8.5000), tensor(8.)]
[tensor(8.5000), tensor(8.)]


Epoch 1/1 Training:  46%|████▌     | 104/225 [09:26<13:47,  6.84s/it, loss=1.13, learning_rate=5.56e-6, global_step=26]

[tensor(8.5000), tensor(9.5000)]
[tensor(3.), tensor(8.5000)]


Epoch 1/1 Training:  47%|████▋     | 105/225 [09:43<19:26,  9.72s/it, loss=1.13, learning_rate=5.56e-6, global_step=26]

[tensor(8.5000), tensor(8.5000)]
[tensor(7.5000), tensor(8.)]


Epoch 1/1 Training:  47%|████▋     | 106/225 [09:46<15:40,  7.90s/it, loss=1.13, learning_rate=5.56e-6, global_step=26]

[tensor(8.5000), tensor(8.)]
[tensor(7.), tensor(7.)]


Epoch 1/1 Training:  48%|████▊     | 107/225 [09:49<12:38,  6.43s/it, loss=1.13, learning_rate=5.56e-6, global_step=26]

[tensor(8.5000), tensor(7.5000)]
[tensor(8.), tensor(7.)]


Epoch 1/1 Training:  48%|████▊     | 108/225 [09:53<10:37,  5.45s/it, loss=1.12, learning_rate=5.28e-6, global_step=27]

[tensor(6.), tensor(8.)]
[tensor(3.), tensor(4.)]


Epoch 1/1 Training:  48%|████▊     | 109/225 [09:59<11:22,  5.88s/it, loss=1.12, learning_rate=5.28e-6, global_step=27]

[tensor(8.5000), tensor(5.)]
[tensor(7.), tensor(2.)]


Epoch 1/1 Training:  49%|████▉     | 110/225 [10:02<09:37,  5.02s/it, loss=1.12, learning_rate=5.28e-6, global_step=27]

[tensor(9.), tensor(8.)]
[tensor(2.), tensor(7.)]


Epoch 1/1 Training:  49%|████▉     | 111/225 [10:20<16:34,  8.72s/it, loss=1.12, learning_rate=5.28e-6, global_step=27]

[tensor(8.5000), tensor(7.5000)]
[tensor(2.), tensor(6.5000)]


Epoch 1/1 Training:  50%|████▉     | 112/225 [10:23<13:18,  7.07s/it, loss=1.13, learning_rate=5e-6, global_step=28]

[tensor(8.5000), tensor(7.5000)]
[tensor(7.), tensor(7.5000)]


Epoch 1/1 Training:  50%|█████     | 113/225 [10:42<19:51, 10.64s/it, loss=1.13, learning_rate=5e-6, global_step=28]

[tensor(8.), tensor(8.5000)]
[tensor(6.), tensor(6.)]


Epoch 1/1 Training:  51%|█████     | 114/225 [10:45<15:19,  8.29s/it, loss=1.13, learning_rate=5e-6, global_step=28]

[tensor(8.5000), tensor(7.5000)]
[tensor(7.), tensor(6.)]


Epoch 1/1 Training:  51%|█████     | 115/225 [10:48<12:11,  6.65s/it, loss=1.13, learning_rate=5e-6, global_step=28]

[tensor(9.), tensor(7.5000)]
[tensor(7.5000), tensor(5.)]


Epoch 1/1 Training:  52%|█████▏    | 116/225 [10:51<10:12,  5.62s/it, loss=1.13, learning_rate=4.72e-6, global_step=29]

[tensor(7.), tensor(8.)]
[tensor(6.), tensor(5.)]


Epoch 1/1 Training:  52%|█████▏    | 117/225 [10:54<08:42,  4.83s/it, loss=1.13, learning_rate=4.72e-6, global_step=29]

[tensor(7.5000), tensor(8.)]
[tensor(7.), tensor(3.)]


Epoch 1/1 Training:  52%|█████▏    | 118/225 [10:57<07:44,  4.34s/it, loss=1.13, learning_rate=4.72e-6, global_step=29]

[tensor(8.5000), tensor(8.5000)]
[tensor(5.), tensor(8.)]


Epoch 1/1 Training:  53%|█████▎    | 119/225 [11:00<07:06,  4.02s/it, loss=1.13, learning_rate=4.72e-6, global_step=29]

[tensor(8.5000), tensor(8.)]
[tensor(6.5000), tensor(7.5000)]


Epoch 1/1 Training:  53%|█████▎    | 120/225 [11:19<14:54,  8.52s/it, loss=1.13, learning_rate=4.44e-6, global_step=30]

[tensor(8.), tensor(8.5000)]
[tensor(7.), tensor(8.)]


Epoch 1/1 Training:  54%|█████▍    | 121/225 [11:22<11:54,  6.87s/it, loss=1.13, learning_rate=4.44e-6, global_step=30]

[tensor(7.), tensor(7.5000)]
[tensor(5.), tensor(7.)]


Epoch 1/1 Training:  54%|█████▍    | 122/225 [11:25<09:47,  5.70s/it, loss=1.13, learning_rate=4.44e-6, global_step=30]

[tensor(8.5000), tensor(8.5000)]
[tensor(5.), tensor(7.5000)]


Epoch 1/1 Training:  55%|█████▍    | 123/225 [11:28<08:14,  4.85s/it, loss=1.13, learning_rate=4.44e-6, global_step=30]

[tensor(7.), tensor(6.)]
[tensor(5.), tensor(3.)]


Epoch 1/1 Training:  55%|█████▌    | 124/225 [11:31<07:14,  4.30s/it, loss=1.13, learning_rate=4.16e-6, global_step=31]

[tensor(6.), tensor(9.)]
[tensor(3.), tensor(7.5000)]


Epoch 1/1 Training:  56%|█████▌    | 125/225 [11:34<06:24,  3.85s/it, loss=1.13, learning_rate=4.16e-6, global_step=31]

[tensor(8.), tensor(8.5000)]
[tensor(7.), tensor(7.)]


Epoch 1/1 Training:  56%|█████▌    | 126/225 [11:37<05:57,  3.61s/it, loss=1.13, learning_rate=4.16e-6, global_step=31]

[tensor(9.), tensor(8.)]
[tensor(6.), tensor(7.)]


Epoch 1/1 Training:  56%|█████▋    | 127/225 [11:40<05:31,  3.38s/it, loss=1.13, learning_rate=4.16e-6, global_step=31]

[tensor(8.5000), tensor(10.)]
[tensor(3.), tensor(6.5000)]


Epoch 1/1 Training:  57%|█████▋    | 128/225 [11:43<05:24,  3.34s/it, loss=1.13, learning_rate=3.89e-6, global_step=32]

[tensor(4.), tensor(9.5000)]
[tensor(2.), tensor(7.5000)]


Epoch 1/1 Training:  57%|█████▋    | 129/225 [11:46<05:12,  3.25s/it, loss=1.13, learning_rate=3.89e-6, global_step=32]

[tensor(8.5000), tensor(8.5000)]
[tensor(7.5000), tensor(8.)]


Epoch 1/1 Training:  58%|█████▊    | 130/225 [11:49<04:54,  3.10s/it, loss=1.13, learning_rate=3.89e-6, global_step=32]

[tensor(7.5000), tensor(7.5000)]
[tensor(3.), tensor(5.)]


Epoch 1/1 Training:  58%|█████▊    | 131/225 [11:52<04:47,  3.05s/it, loss=1.13, learning_rate=3.89e-6, global_step=32]

[tensor(7.5000), tensor(8.5000)]
[tensor(7.), tensor(8.5000)]


Epoch 1/1 Training:  59%|█████▊    | 132/225 [11:55<04:49,  3.11s/it, loss=1.13, learning_rate=3.62e-6, global_step=33]

[tensor(7.5000), tensor(7.)]
[tensor(3.), tensor(7.)]


Epoch 1/1 Training:  59%|█████▉    | 133/225 [11:58<04:32,  2.97s/it, loss=1.13, learning_rate=3.62e-6, global_step=33]

[tensor(8.5000), tensor(8.5000)]
[tensor(8.), tensor(3.)]


Epoch 1/1 Training:  60%|█████▉    | 134/225 [12:01<04:29,  2.96s/it, loss=1.13, learning_rate=3.62e-6, global_step=33]

[tensor(8.5000), tensor(8.)]
[tensor(7.5000), tensor(2.)]


Epoch 1/1 Training:  60%|██████    | 135/225 [12:04<04:22,  2.91s/it, loss=1.13, learning_rate=3.62e-6, global_step=33]

[tensor(8.5000), tensor(8.)]
[tensor(7.), tensor(5.)]


Epoch 1/1 Training:  60%|██████    | 136/225 [12:07<04:27,  3.00s/it, loss=1.13, learning_rate=3.35e-6, global_step=34]

[tensor(8.5000), tensor(7.5000)]
[tensor(7.), tensor(6.5000)]


Epoch 1/1 Training:  61%|██████    | 137/225 [12:09<04:15,  2.90s/it, loss=1.13, learning_rate=3.35e-6, global_step=34]

[tensor(8.5000), tensor(9.)]
[tensor(7.5000), tensor(8.)]


Epoch 1/1 Training:  61%|██████▏   | 138/225 [12:12<04:13,  2.91s/it, loss=1.13, learning_rate=3.35e-6, global_step=34]

[tensor(8.5000), tensor(9.)]
[tensor(3.), tensor(4.)]


Epoch 1/1 Training:  62%|██████▏   | 139/225 [12:15<04:00,  2.80s/it, loss=1.13, learning_rate=3.35e-6, global_step=34]

[tensor(7.5000), tensor(8.)]
[tensor(6.), tensor(7.5000)]


Epoch 1/1 Training:  62%|██████▏   | 140/225 [12:18<04:15,  3.00s/it, loss=1.13, learning_rate=3.09e-6, global_step=35]

[tensor(8.5000), tensor(7.)]
[tensor(8.), tensor(5.)]


Epoch 1/1 Training:  63%|██████▎   | 141/225 [12:21<04:15,  3.04s/it, loss=1.13, learning_rate=3.09e-6, global_step=35]

[tensor(8.), tensor(7.5000)]
[tensor(7.5000), tensor(7.)]


Epoch 1/1 Training:  63%|██████▎   | 142/225 [12:39<10:13,  7.39s/it, loss=1.13, learning_rate=3.09e-6, global_step=35]

[tensor(7.), tensor(8.)]
[tensor(3.), tensor(7.)]


Epoch 1/1 Training:  64%|██████▎   | 143/225 [12:43<08:36,  6.30s/it, loss=1.13, learning_rate=3.09e-6, global_step=35]

[tensor(7.), tensor(8.)]
[tensor(2.), tensor(3.)]


Epoch 1/1 Training:  64%|██████▍   | 144/225 [12:46<07:13,  5.35s/it, loss=1.13, learning_rate=2.83e-6, global_step=36]

[tensor(4.), tensor(8.)]
[tensor(2.), tensor(8.)]


Epoch 1/1 Training:  64%|██████▍   | 145/225 [12:49<06:19,  4.74s/it, loss=1.13, learning_rate=2.83e-6, global_step=36]

[tensor(9.), tensor(7.)]
[tensor(7.), tensor(6.)]


Epoch 1/1 Training:  65%|██████▍   | 146/225 [12:52<05:27,  4.15s/it, loss=1.13, learning_rate=2.83e-6, global_step=36]

[tensor(9.), tensor(7.)]
[tensor(7.), tensor(7.)]


Epoch 1/1 Training:  65%|██████▌   | 147/225 [12:55<04:54,  3.78s/it, loss=1.13, learning_rate=2.83e-6, global_step=36]

[tensor(9.), tensor(7.)]
[tensor(7.), tensor(7.)]


Epoch 1/1 Training:  66%|██████▌   | 148/225 [13:12<09:59,  7.79s/it, loss=1.13, learning_rate=2.58e-6, global_step=37]

[tensor(8.5000), tensor(8.5000)]
[tensor(5.), tensor(4.)]


Epoch 1/1 Training:  66%|██████▌   | 149/225 [13:30<13:35, 10.73s/it, loss=1.13, learning_rate=2.58e-6, global_step=37]

[tensor(8.), tensor(6.5000)]
[tensor(4.), tensor(6.)]


Epoch 1/1 Training:  67%|██████▋   | 150/225 [13:48<16:16, 13.03s/it, loss=1.13, learning_rate=2.58e-6, global_step=37]

[tensor(7.5000), tensor(8.5000)]
[tensor(7.5000), tensor(6.5000)]


Epoch 1/1 Training:  67%|██████▋   | 151/225 [13:51<12:24, 10.06s/it, loss=1.13, learning_rate=2.58e-6, global_step=37]

[tensor(8.), tensor(7.5000)]
[tensor(7.), tensor(5.)]


Epoch 1/1 Training:  68%|██████▊   | 152/225 [13:54<09:42,  7.98s/it, loss=1.13, learning_rate=2.34e-6, global_step=38]

[tensor(9.), tensor(7.5000)]
[tensor(3.), tensor(3.)]


Epoch 1/1 Training:  68%|██████▊   | 153/225 [14:11<12:53, 10.74s/it, loss=1.13, learning_rate=2.34e-6, global_step=38]

[tensor(7.), tensor(8.)]
[tensor(7.), tensor(7.)]


Epoch 1/1 Training:  68%|██████▊   | 154/225 [14:14<09:57,  8.42s/it, loss=1.13, learning_rate=2.34e-6, global_step=38]

[tensor(8.), tensor(7.5000)]
[tensor(6.), tensor(4.)]


Epoch 1/1 Training:  69%|██████▉   | 155/225 [14:17<07:51,  6.73s/it, loss=1.13, learning_rate=2.34e-6, global_step=38]

[tensor(8.5000), tensor(7.5000)]
[tensor(7.), tensor(7.)]


Epoch 1/1 Training:  69%|██████▉   | 156/225 [14:21<06:33,  5.71s/it, loss=1.13, learning_rate=2.11e-6, global_step=39]

[tensor(8.5000), tensor(9.)]
[tensor(7.5000), tensor(8.5000)]


Epoch 1/1 Training:  70%|██████▉   | 157/225 [14:24<05:38,  4.98s/it, loss=1.13, learning_rate=2.11e-6, global_step=39]

[tensor(8.5000), tensor(8.)]
[tensor(8.), tensor(7.)]


Epoch 1/1 Training:  70%|███████   | 158/225 [14:27<04:48,  4.30s/it, loss=1.13, learning_rate=2.11e-6, global_step=39]

[tensor(8.5000), tensor(8.5000)]
[tensor(6.), tensor(7.)]


Epoch 1/1 Training:  71%|███████   | 159/225 [14:29<04:13,  3.85s/it, loss=1.13, learning_rate=2.11e-6, global_step=39]

[tensor(8.5000), tensor(6.5000)]
[tensor(6.), tensor(5.)]


Epoch 1/1 Training:  71%|███████   | 160/225 [14:32<03:54,  3.61s/it, loss=1.13, learning_rate=1.88e-6, global_step=40]

[tensor(8.), tensor(7.)]
[tensor(7.), tensor(6.5000)]


Epoch 1/1 Training:  72%|███████▏  | 161/225 [14:35<03:37,  3.40s/it, loss=1.13, learning_rate=1.88e-6, global_step=40]

[tensor(8.), tensor(7.5000)]
[tensor(3.), tensor(4.)]


Epoch 1/1 Training:  72%|███████▏  | 162/225 [14:38<03:21,  3.20s/it, loss=1.13, learning_rate=1.88e-6, global_step=40]

[tensor(8.), tensor(7.5000)]
[tensor(6.), tensor(6.)]


Epoch 1/1 Training:  72%|███████▏  | 163/225 [14:43<03:42,  3.59s/it, loss=1.13, learning_rate=1.88e-6, global_step=40]

[tensor(8.5000), tensor(8.5000)]
[tensor(7.), tensor(7.)]


Epoch 1/1 Training:  73%|███████▎  | 164/225 [14:46<03:29,  3.44s/it, loss=1.13, learning_rate=1.67e-6, global_step=41]

[tensor(8.5000), tensor(8.)]
[tensor(8.), tensor(8.)]


Epoch 1/1 Training:  73%|███████▎  | 165/225 [15:04<07:49,  7.82s/it, loss=1.13, learning_rate=1.67e-6, global_step=41]

[tensor(8.5000), tensor(7.5000)]
[tensor(3.), tensor(6.)]
