In [None]:
# pip installs

!pip install -q --upgrade torch==2.5.1+cu124 torchvision==0.20.1+cu124 torchaudio==2.5.1+cu124 --index-url https://download.pytorch.org/whl/cu124
!pip install -q --upgrade requests==2.32.3 bitsandbytes==0.46.0 transformers==4.48.3 accelerate==1.3.0 datasets==3.2.0 peft==0.14.0 trl==0.14.0 matplotlib wandb

In [None]:
# imports
import re
import math
import numpy as np
from tqdm import tqdm
from google.colab import userdata
from huggingface_hub import login
import torch
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, set_seed
from datasets import load_dataset
from peft import PeftModel
import matplotlib.pyplot as plt

# Prefer the GPU when PyTorch can see one, otherwise fall back to the CPU
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using device: {device}")

In [None]:
# Constants

# Model identifiers on the Hugging Face Hub
BASE_MODEL = "meta-llama/Meta-Llama-3.1-8B"
HF_USER = "ed-donner"
PROJECT_NAME = "pricer"
RUN_NAME = "2024-09-13_13.04.39"
PROJECT_RUN_NAME = "-".join([PROJECT_NAME, RUN_NAME])
FINETUNED_MODEL = "/".join([HF_USER, PROJECT_RUN_NAME])

# Commit of the fine-tuned adapter to load; a falsy value loads it without pinning a revision
REVISION = "e8d637df551603dc86cd7a1598a8f44af4d7ae36"

# Dataset repository under the same Hub user
DATASET_NAME = "/".join([HF_USER, "pricer-data"])

# Quantization switch: True selects the 4-bit config, False the 8-bit config
QUANT_4_BIT = True

# Select matplotlib's inline backend so figures are drawn in the notebook output
%matplotlib inline

# ANSI escape sequences for coloured console output; RESET restores the default colour

RESET = "\033[0m"
RED = "\033[91m"
YELLOW = "\033[93m"
GREEN = "\033[92m"

# Maps the colour names used by the evaluation code to their escape sequences
COLOR_MAP = dict(red=RED, orange=YELLOW, green=GREEN)

In [None]:
# Authenticate with the Hugging Face Hub using the HF_TOKEN secret stored in Colab

hf_token = userdata.get('HF_TOKEN')
login(token=hf_token, add_to_git_credential=True)

In [None]:
dataset = load_dataset(DATASET_NAME)
train_full = dataset['train']
test_full = dataset['test']

# Number of rows to keep from each split; None keeps the whole split
TRAIN_SIZE = 8000  # Very small for testing
TEST_SIZE = 2000    # Very small for testing


def _take_first(split, limit):
    """Return the first `limit` rows of `split`, or every row when `limit` is None."""
    available = len(split)
    count = available if limit is None else min(limit, available)
    return split.select(range(count))


train = _take_first(train_full, TRAIN_SIZE)
test = _take_first(test_full, TEST_SIZE)

print("Using small test dataset:")
print(f"  Train samples: {len(train)} (full dataset has {len(train_full)})")
print(f"  Test samples: {len(test)} (full dataset has {len(test_full)})")
print("\nTo use full dataset, set TRAIN_SIZE and TEST_SIZE to None or large numbers")

In [None]:
# Quantization settings handed to the model loader:
# 4-bit NF4 with double quantization and bfloat16 compute, or plain 8-bit
quant_config = (
    BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=True,
        bnb_4bit_compute_dtype=torch.bfloat16,
    )
    if QUANT_4_BIT
    else BitsAndBytesConfig(load_in_8bit=True)
)

In [None]:
# Load the Tokenizer and the Model

tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, trust_remote_code=True)
# The tokenizer is given the EOS token as its padding token
tokenizer.pad_token = tokenizer.eos_token
# This notebook only runs inference. A decoder-only model continues from the last
# position of each row, so padded batches must be padded on the LEFT; with right
# padding the shorter prompts would be continued from a pad token.
tokenizer.padding_side = "left"

base_model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL,
    quantization_config=quant_config,
    device_map="auto",
)
base_model.generation_config.pad_token_id = tokenizer.pad_token_id

# Load the fine-tuned model with PEFT
if REVISION:
    fine_tuned_model = PeftModel.from_pretrained(base_model, FINETUNED_MODEL, revision=REVISION)
else:
    fine_tuned_model = PeftModel.from_pretrained(base_model, FINETUNED_MODEL)

# Switch to evaluation mode for inference
fine_tuned_model.eval()

In [None]:
def extract_price(s):
    """
    Pull the numeric price out of model output shaped like 'Price is $X.XX'.

    Returns the first number that follows the first 'Price is $' marker as a float
    (commas are stripped first), or None when `s` is empty, is not a string,
    lacks the marker, or has no number after it.
    """
    marker = "Price is $"
    if not s or not isinstance(s, str):
        return None
    if marker not in s:
        return None

    # Only the text between the first marker and the next one (if any) is searched
    after_marker = s.split(marker)[1].replace(',', '')
    number = re.search(r"[-+]?\d*\.\d+|\d+", after_marker)
    if number is None:
        return None

    try:
        return float(number.group())
    except (ValueError, AttributeError):
        return None

In [None]:
# Original prediction function - greedy decoding (supports batch processing)

def model_predict(prompt, device=device, batch_mode=False):
    """
    Predict a price for a prompt using greedy decoding.

    Args:
        prompt: Prompt string, or a list of prompt strings when batch_mode is True
        device: Device the encoded prompt is moved to
        batch_mode: If True and prompt is a list, delegate to model_predict_batch

    Returns:
        The extracted price as a float (a list of floats in batch mode).
        0.0 is returned when no price can be extracted or an error occurs.
    """
    # Handle batch mode
    if batch_mode and isinstance(prompt, list):
        return model_predict_batch(prompt, device)

    set_seed(42)
    try:
        inputs = tokenizer.encode(prompt, return_tensors="pt").to(device)
        # A single un-padded prompt: every position is a real token
        attention_mask = torch.ones_like(inputs)

        with torch.no_grad():
            # Greedy decoding. No temperature is passed: it is only used when
            # sampling and is ignored when do_sample=False.
            outputs = fine_tuned_model.generate(
                inputs,
                attention_mask=attention_mask,
                max_new_tokens=15,
                num_return_sequences=1,
                do_sample=False,
                pad_token_id=tokenizer.pad_token_id
            )

        # The decoded text contains the prompt followed by the generated tokens
        response = tokenizer.decode(outputs[0], skip_special_tokens=True)
        price = extract_price(response)
        return price if price is not None else 0.0
    except Exception as e:
        # Best effort: report the problem and fall back to a zero prediction
        print(f"Error in model_predict: {e}")
        return 0.0

def model_predict_batch(prompts, device=device):
    """
    Batch prediction for multiple prompts at once - much faster!

    Args:
        prompts: List of prompt strings
        device: Device the encoded batch is moved to

    Returns:
        One float per prompt, in order. 0.0 is used for a prompt whose price
        cannot be extracted, and for every prompt if an error occurs.
    """
    set_seed(42)
    if not prompts:
        return []

    try:
        # A decoder-only model continues from the last position of each row, so
        # the batch must be padded on the LEFT; with right padding the shorter
        # prompts would be continued from a pad token. The tokenizer's own
        # setting is restored afterwards.
        previous_side = tokenizer.padding_side
        tokenizer.padding_side = "left"
        try:
            inputs = tokenizer(
                prompts,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=512
            ).to(device)
        finally:
            tokenizer.padding_side = previous_side

        with torch.no_grad():
            # Greedy decoding. No temperature is passed: it is only used when
            # sampling and is ignored when do_sample=False.
            outputs = fine_tuned_model.generate(
                **inputs,
                max_new_tokens=15,
                num_return_sequences=1,
                do_sample=False,
                pad_token_id=tokenizer.pad_token_id
            )

        # Decode all responses (each contains the prompt plus the generated tokens)
        responses = tokenizer.batch_decode(outputs, skip_special_tokens=True)

        # Extract prices for all responses, substituting 0.0 where none is found
        prices = []
        for response in responses:
            price = extract_price(response)
            prices.append(price if price is not None else 0.0)

        return prices
    except Exception as e:
        # Best effort: report the problem and fall back to zero predictions
        print(f"Error in model_predict_batch: {e}")
        return [0.0] * len(prompts)

In [None]:
# Improved prediction function with dual strategy: full generation + fallback to weighted top-K
# Supports batch processing for faster inference

# Number of most likely next tokens considered by the fallback strategy
top_K = 6


def _encode_left_padded(prompts, device):
    """Tokenize a list of prompts into a left-padded batch on `device`."""
    # A decoder-only model continues from the last position of each row, so padding
    # must go on the left; otherwise generation and the last-position logits of the
    # shorter prompts would come from a pad token. The tokenizer's own setting is
    # restored afterwards.
    previous_side = tokenizer.padding_side
    tokenizer.padding_side = "left"
    try:
        return tokenizer(
            prompts,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=512
        ).to(device)
    finally:
        tokenizer.padding_side = previous_side


def _weighted_top_k_price(token_probs, token_ids):
    """Probability-weighted average of the candidate tokens that decode to a positive finite number."""
    prices, weights = [], []
    for probability, token_id in zip(token_probs.tolist(), token_ids.tolist()):
        predicted_token = tokenizer.decode([token_id], skip_special_tokens=True)
        try:
            result = float(predicted_token)
        except ValueError:
            continue
        # float() also accepts text such as "inf" or "nan"; those are not prices
        if math.isfinite(result) and result > 0:
            prices.append(result)
            weights.append(probability)

    total = sum(weights)
    if not prices or total == 0:
        return 0.0
    return sum(price * weight / total for price, weight in zip(prices, weights))


def _top_k_next_token(logits):
    """Return (probabilities, token ids) of the top_K next tokens for each row of last-position logits."""
    # Softmax in float32 on the CPU, independent of the model's compute dtype
    next_token_probs = F.softmax(logits.float().to('cpu'), dim=-1)
    return next_token_probs.topk(top_K)


def improved_model_predict(prompt, device=device, max_tokens=15, batch_mode=False):
    """
    Improved prediction using dual strategy:
    1. Full generation and extract price (handles multi-token prices)
    2. Fallback to weighted average of top-K token probabilities

    Args:
        prompt: Single string or list of strings for batch processing
        device: Device to use
        max_tokens: Maximum tokens to generate
        batch_mode: If True and prompt is a list, processes all at once (much faster!)

    Returns:
        The predicted price as a float (a list of floats in batch mode).
        0.0 is returned when neither strategy yields a price or an error occurs.
    """
    # Handle batch mode
    if batch_mode and isinstance(prompt, list):
        return improved_model_predict_batch(prompt, device, max_tokens)

    set_seed(42)
    try:
        inputs = tokenizer.encode(prompt, return_tensors="pt").to(device)
        # A single un-padded prompt: every position is a real token
        attention_mask = torch.ones_like(inputs)

        # Strategy 1: Full generation and extract price (handles multi-token prices)
        with torch.no_grad():
            # Greedy decoding. No temperature is passed: it is only used when
            # sampling and is ignored when do_sample=False.
            outputs = fine_tuned_model.generate(
                inputs,
                attention_mask=attention_mask,
                max_new_tokens=max_tokens,
                num_return_sequences=1,
                do_sample=False,
                pad_token_id=tokenizer.pad_token_id
            )
        full_response = tokenizer.decode(outputs[0], skip_special_tokens=True)
        extracted_price = extract_price(full_response)

        if extracted_price is not None and extracted_price > 0:
            return float(extracted_price)

        # Strategy 2: Fallback to single-token weighted average
        with torch.no_grad():
            forward_outputs = fine_tuned_model(inputs, attention_mask=attention_mask)
            next_token_logits = forward_outputs.logits[:, -1, :]

        top_probs, top_token_ids = _top_k_next_token(next_token_logits)
        return _weighted_top_k_price(top_probs[0], top_token_ids[0])

    except Exception as e:
        # Best effort: report the problem and fall back to a zero prediction
        print(f"Error in improved_model_predict: {e}")
        return 0.0


def improved_model_predict_batch(prompts, device=device, max_tokens=15):
    """
    Batch version of improved_model_predict - processes multiple prompts in parallel.
    This is MUCH faster than calling improved_model_predict in a loop!

    Args:
        prompts: List of prompt strings
        device: Device to use
        max_tokens: Maximum tokens to generate

    Returns:
        One float per prompt, in order. 0.0 is used where neither strategy
        yields a price, and for every prompt if an error occurs.
    """
    set_seed(42)
    if not prompts:
        return []

    try:
        # Tokenize all prompts at once (left-padded)
        inputs = _encode_left_padded(prompts, device)

        # Strategy 1: Full generation for all prompts at once
        with torch.no_grad():
            outputs = fine_tuned_model.generate(
                **inputs,
                max_new_tokens=max_tokens,
                num_return_sequences=1,
                do_sample=False,
                pad_token_id=tokenizer.pad_token_id
            )

        # Decode all responses
        responses = tokenizer.batch_decode(outputs, skip_special_tokens=True)

        # Extract prices - try Strategy 1 first, remembering which items need the fallback
        prices = []
        fallback_indices = []
        for idx, response in enumerate(responses):
            extracted_price = extract_price(response)
            if extracted_price is not None and extracted_price > 0:
                prices.append(float(extracted_price))
            else:
                prices.append(None)  # Mark for fallback
                fallback_indices.append(idx)

        # Strategy 2: Fallback for items that failed Strategy 1
        if fallback_indices:
            # Re-encode only the ones that need fallback, so the batch is no wider than needed
            fallback_inputs = _encode_left_padded(
                [prompts[idx] for idx in fallback_indices], device
            )

            with torch.no_grad():
                fallback_outputs = fine_tuned_model(**fallback_inputs)
                # With left padding the last position of every row is a real prompt token
                next_token_logits = fallback_outputs.logits[:, -1, :]

            top_probs, top_token_ids = _top_k_next_token(next_token_logits)

            for batch_idx, original_idx in enumerate(fallback_indices):
                prices[original_idx] = _weighted_top_k_price(
                    top_probs[batch_idx], top_token_ids[batch_idx]
                )

        # Replace None with 0.0
        return [p if p is not None else 0.0 for p in prices]

    except Exception as e:
        # Best effort: report the problem and fall back to zero predictions
        print(f"Error in improved_model_predict_batch: {e}")
        return [0.0] * len(prompts)

In [None]:
class Tester:
    """
    Runs a price predictor over labelled datapoints, prints one colour-coded line
    per item and produces summary metrics plus charts.

    Each datapoint is read through datapoint["text"] (the prompt) and
    datapoint["price"] (the ground truth).
    """

    def __init__(self, predictor, data, title=None, size=250):
        """
        Args:
            predictor: Callable taking a prompt string and returning a price. For batched
                runs it is called as predictor(list_of_prompts, batch_mode=True) and is
                expected to return one price per prompt.
            data: Indexable collection of datapoints
            title: Report title; defaults to a prettified version of the predictor's name
            size: Maximum number of datapoints to evaluate
        """
        self.predictor = predictor
        self.data = data
        # Callables such as functools.partial objects have no __name__
        self.title = title or getattr(predictor, "__name__", "predictor").replace("_", " ").title()
        self.size = min(size, len(data)) if data else size
        self._reset_results()

    def _reset_results(self):
        """Clear all per-datapoint results."""
        self.guesses = []
        self.truths = []
        self.errors = []
        self.sles = []
        self.colors = []
        self.relative_errors = []

    def color_for(self, error, truth):
        """Determine color with safe division handling"""
        if truth == 0:
            # If truth is 0, use absolute error only
            if error < 40:
                return "green"
            if error < 80:
                return "orange"
            return "red"

        relative_error = error / truth
        if error < 40 or relative_error < 0.2:
            return "green"
        if error < 80 or relative_error < 0.4:
            return "orange"
        return "red"

    @staticmethod
    def _clean_guess(guess):
        """Map invalid guesses (None, empty tuple, negative, NaN/inf) to 0.0; unwrap tuples."""
        if isinstance(guess, tuple):
            guess = guess[0] if len(guess) > 0 else 0.0
        if guess is None or not math.isfinite(guess) or guess < 0:
            return 0.0
        return guess

    def run_datapoint(self, i):
        """Test a single datapoint"""
        guess = self.predictor(self.data[i]["text"])
        self.run_datapoint_internal(i, guess)

    def chart(self, title):
        """Create comprehensive visualization"""
        if not self.truths:
            raise ValueError("No predictions recorded; nothing to chart")

        fig, axes = plt.subplots(2, 2, figsize=(16, 12))

        # 1. Scatter plot: Predictions vs Truth
        ax1 = axes[0, 0]
        max_val = max(max(self.truths), max(self.guesses)) * 1.1
        ax1.plot([0, max_val], [0, max_val], color='deepskyblue', lw=2, alpha=0.6, label='Perfect prediction')
        ax1.scatter(self.truths, self.guesses, s=20, c=self.colors, alpha=0.6)
        ax1.set_xlabel('Ground Truth Price ($)', fontsize=12)
        ax1.set_ylabel('Predicted Price ($)', fontsize=12)
        ax1.set_xlim(0, max_val)
        ax1.set_ylim(0, max_val)
        ax1.set_title('Predictions vs Ground Truth', fontsize=14)
        ax1.legend()
        ax1.grid(True, alpha=0.3)

        # 2. Error distribution histogram
        ax2 = axes[0, 1]
        ax2.hist(self.errors, bins=30, color='skyblue', alpha=0.7, edgecolor='black')
        ax2.axvline(np.mean(self.errors), color='red', linestyle='--', label='Mean Error')
        ax2.set_xlabel('Absolute Error ($)', fontsize=12)
        ax2.set_ylabel('Frequency', fontsize=12)
        ax2.set_title('Error Distribution', fontsize=14)
        ax2.legend()
        ax2.grid(True, alpha=0.3)

        # 3. Relative error distribution
        ax3 = axes[1, 0]
        relative_errors_pct = [e * 100 for e in self.relative_errors]
        ax3.hist(relative_errors_pct, bins=30, color='lightcoral', alpha=0.7, edgecolor='black')
        ax3.set_xlabel('Relative Error (%)', fontsize=12)
        ax3.set_ylabel('Frequency', fontsize=12)
        ax3.set_title('Relative Error Distribution', fontsize=14)
        ax3.grid(True, alpha=0.3)

        # 4. Accuracy by price range (ranges without any datapoint are omitted)
        ax4 = axes[1, 1]
        price_ranges = [(0, 50), (50, 100), (100, 200), (200, 500), (500, float('inf'))]
        range_errors = []
        range_labels = []
        for low, high in price_ranges:
            in_range = [err for err, truth in zip(self.errors, self.truths) if low <= truth < high]
            if in_range:
                range_errors.append(np.mean(in_range))
                # The open-ended top range is labelled "$500+" rather than "$500-$+"
                range_labels.append(f"${low}+" if high == float('inf') else f"${low}-${high}")

        ax4.bar(range_labels, range_errors, color='steelblue', alpha=0.7)
        ax4.set_xlabel('Price Range ($)', fontsize=12)
        ax4.set_ylabel('Average Error ($)', fontsize=12)
        ax4.set_title('Average Error by Price Range', fontsize=14)
        ax4.tick_params(axis='x', rotation=45)
        ax4.grid(True, alpha=0.3, axis='y')

        plt.tight_layout()
        fig.suptitle(title, fontsize=16, y=1.02)
        plt.show()

    def calculate_metrics(self):
        """
        Calculate comprehensive evaluation metrics over the recorded datapoints.

        Returns:
            dict with keys 'mae', 'median_error', 'rmse', 'rmsle', 'mape', 'r2',
            'hit_rate_green' and 'hit_rate_acceptable' (the last three-letter
            rates and 'mape' are percentages).

        Raises:
            ValueError: if no predictions have been recorded yet.
        """
        count = len(self.errors)
        if count == 0:
            raise ValueError("No predictions recorded; run the test before calculating metrics")

        guesses_arr = np.array(self.guesses, dtype=float)
        truths_arr = np.array(self.truths, dtype=float)
        errors_arr = np.array(self.errors, dtype=float)

        # Percentage error is undefined for a zero ground truth, so those items are
        # left out of MAPE instead of being divided by zero and counted as 0%.
        has_positive_truth = truths_arr > 0
        if has_positive_truth.any():
            mape = float(np.mean(errors_arr[has_positive_truth] / truths_arr[has_positive_truth]) * 100)
        else:
            mape = 0.0

        # Averages use the number of recorded results, which may be smaller than
        # self.size if the run was interrupted.
        metrics = {
            'mae': float(np.mean(errors_arr)),  # Mean Absolute Error
            'median_error': float(np.median(errors_arr)),
            'rmse': float(np.sqrt(np.mean(errors_arr ** 2))),  # Root Mean Squared Error
            'rmsle': math.sqrt(sum(self.sles) / count),
            'mape': mape,
        }

        # R² (coefficient of determination); defined as 0 when all truths are identical
        ss_res = np.sum((truths_arr - guesses_arr) ** 2)
        ss_tot = np.sum((truths_arr - np.mean(truths_arr)) ** 2)
        metrics['r2'] = float(1 - (ss_res / ss_tot)) if ss_tot > 0 else 0

        # Hit rates
        hits_green = sum(1 for c in self.colors if c == "green")
        hits_orange_green = sum(1 for c in self.colors if c in ["green", "orange"])
        metrics['hit_rate_green'] = hits_green / count * 100
        metrics['hit_rate_acceptable'] = hits_orange_green / count * 100

        return metrics

    def report(self):
        """Print the summary metrics, draw the charts and return the metrics dict."""
        metrics = self.calculate_metrics()

        print(f"\n{'='*70}")
        print(f"FINAL REPORT: {self.title}")
        print(f"{'='*70}")
        print(f"Total Predictions: {len(self.guesses)}")
        print(f"\n--- Error Metrics ---")
        print(f"Mean Absolute Error (MAE):      ${metrics['mae']:,.2f}")
        print(f"Median Error:                   ${metrics['median_error']:,.2f}")
        print(f"Root Mean Squared Error (RMSE): ${metrics['rmse']:,.2f}")
        print(f"Root Mean Squared Log Error:    {metrics['rmsle']:.4f}")
        print(f"Mean Absolute Percentage Error: {metrics['mape']:.2f}%")
        print(f"\n--- Accuracy Metrics ---")
        print(f"R² Score (Coefficient of Determination): {metrics['r2']:.4f}")
        print(f"Hit Rate (Green - Excellent):   {metrics['hit_rate_green']:.1f}%")
        print(f"Hit Rate (Green+Orange - Good): {metrics['hit_rate_acceptable']:.1f}%")
        print(f"{'='*70}\n")

        # Create visualization
        chart_title = f"{self.title} | MAE=${metrics['mae']:,.2f} | RMSLE={metrics['rmsle']:.3f} | R²={metrics['r2']:.3f}"
        self.chart(chart_title)

        return metrics

    def run(self, show_progress=True, batch_size=8):
        """
        Run test on all datapoints with progress bar.

        Args:
            show_progress: Show progress bar
            batch_size: Process this many items at once (0 or 1 = no batching, process one by one)

        Returns:
            The metrics dict produced by report().
        """
        print(f"Testing {self.size} predictions with {self.title}...\n")

        # Start from a clean slate so that calling run() again does not mix in old results
        self._reset_results()

        if batch_size and batch_size > 1:
            # Batch processing mode - much faster!
            print(f"Using batch processing with batch_size={batch_size}")
            texts = [self.data[i]["text"] for i in range(self.size)]

            batch_starts = range(0, self.size, batch_size)
            iterator = tqdm(batch_starts, desc="Batch Predicting") if show_progress else batch_starts

            for batch_start in iterator:
                batch_texts = texts[batch_start:batch_start + batch_size]

                # Get batch predictions
                batch_guesses = self.predictor(batch_texts, batch_mode=True)

                # Process each result in the batch
                for offset, guess in enumerate(batch_guesses):
                    self.run_datapoint_internal(batch_start + offset, guess)
        else:
            # Sequential processing (original method)
            iterator = tqdm(range(self.size), desc="Predicting") if show_progress else range(self.size)
            for i in iterator:
                self.run_datapoint(i)

        return self.report()

    def run_datapoint_internal(self, i, guess):
        """Internal method to process a single datapoint when we already have the guess"""
        datapoint = self.data[i]
        truth = float(datapoint["price"])
        guess = self._clean_guess(guess)

        error = abs(guess - truth)
        relative_error = error / truth if truth > 0 else error
        log_error = math.log(truth + 1) - math.log(guess + 1)
        sle = log_error ** 2
        color = self.color_for(error, truth)

        # Extract item title safely: the second "\n\n"-separated part of the prompt
        try:
            title_parts = datapoint["text"].split("\n\n")
            title = (title_parts[1][:40] + "...") if len(title_parts) > 1 else "Unknown"
        except Exception:
            title = "Unknown"

        self.guesses.append(guess)
        self.truths.append(truth)
        self.errors.append(error)
        self.relative_errors.append(relative_error)
        self.sles.append(sle)
        self.colors.append(color)

        print(f"{COLOR_MAP[color]}{i+1}: Guess: ${guess:,.2f} Truth: ${truth:,.2f} Error: ${error:,.2f} ({relative_error*100:.1f}%) SLE: {sle:.4f} Item: {title}{RESET}")

    @classmethod
    def test(cls, function, data, title=None, size=250, batch_size=8):
        """Quick test method with optional batch processing"""
        return cls(function, data, title, size).run(batch_size=batch_size)

In [None]:
# Evaluate the fine-tuned model on every selected test sample
test_size = len(test)
batch_size = 1 # increase to 2 for faster processing

print(f"Running test with {test_size} samples, batch_size={batch_size}")

# The title names the model actually under test: the Llama 3.1 8B base with the fine-tuned adapter
results = Tester.test(
    improved_model_predict,
    test,
    title="Llama 3.1 8B Fine-tuned (Improved - Test Mode)",
    size=test_size,
    batch_size=batch_size
)