In [None]:
# Example evaluation code structure
from transformers import pipeline
from datasets import load_metric
import pandas as pd

# Wrap the fine-tuned checkpoint in a Hugging Face text-generation pipeline.
# NOTE(review): "your_fine_tuned_model" is a placeholder model id/path.
model = pipeline("text-generation", model="your_fine_tuned_model")

# Text-similarity metrics consumed by evaluate_model() below (BLEU first,
# then ROUGE).
# NOTE(review): datasets.load_metric is deprecated and was removed in
# datasets 3.x (evaluate.load is the replacement) — confirm installed version.
bleu, rouge = load_metric("bleu"), load_metric("rouge")

def evaluate_model(test_cases):
    """Run the fine-tuned ``model`` over *test_cases* and score each prediction.

    Args:
        test_cases: iterable of ``(input_text, expected_output)`` string pairs.

    Returns:
        pandas.DataFrame with one row per test case and the columns
        ``input``, ``expected``, ``predicted``, ``bleu_score``,
        ``rouge_score`` and ``rule_compliance``.
    """
    results = []

    for input_text, expected_output in test_cases:
        # Generate prediction
        generated = model(input_text)[0]["generated_text"]
        # The text-generation pipeline returns prompt + continuation by
        # default (return_full_text=True). Score only what the model wrote,
        # otherwise the echoed prompt distorts every metric.
        if generated.startswith(input_text):
            generated = generated[len(input_text):]
        prediction = generated.strip()

        # datasets' "bleu" metric expects pre-tokenised input: one token list
        # per prediction and a list of token lists per reference. Raw strings
        # would be treated as sequences of characters.
        prediction_tokens = prediction.split()
        if prediction_tokens:
            bleu_score = bleu.compute(
                predictions=[prediction_tokens],
                references=[[expected_output.split()]],
            )
        else:
            # BLEU's brevity penalty divides by the hypothesis length, so an
            # empty continuation can raise; score it as zero instead.
            bleu_score = {"bleu": 0.0}
        rouge_score = rouge.compute(
            predictions=[prediction],
            references=[expected_output],
        )

        # Check rule compliance
        rule_compliance = check_rule_compliance(prediction)

        results.append({
            'input': input_text,
            'expected': expected_output,
            'predicted': prediction,
            'bleu_score': bleu_score,
            'rouge_score': rouge_score,
            'rule_compliance': rule_compliance
        })

    return pd.DataFrame(results)

def check_rule_compliance(text):
    """Check *text* against the traffic-report style rules.

    Placeholder: no checks are implemented yet, so this always returns None.
    Planned checks:
      1. correct road type usage
      2. proper direction terminology
      3. format compliance
      4. emergency information timing
    """
    return None

# Example usage: each case pairs a short English incident note (model input)
# with the Slovenian broadcast text the model is expected to produce.
test_cases = [
    (
        "accident on A1 near Koper, traffic jam 2km",
        "Na primorski avtocesti proti Kopru je zaradi nesreče nastal zastoj dolg 2 kilometra med razcepom Kozarje in priključkom Koper.",
    ),
    # Add more test cases
]

results = evaluate_model(test_cases)
print(results)

In [None]:
from traffic_parser import TrafficReportParser
from traffic_schema import TrafficEvent, RoadType, EventType

def validate_model_output(predicted_text: str) -> dict:
    """Validate the model's output using the TrafficReportParser.

    Args:
        predicted_text: text generated by the model.

    Returns:
        dict with keys:
          - ``is_valid``: True only if the text parsed into an event AND no
            required field (road section, event type, reason, consequence)
            is missing;
          - ``parsed_event``: the parsed event as a dict (``event.dict()``),
            or None if parsing failed;
          - ``errors``: list of human-readable problems found.

    Exceptions raised while parsing are caught and reported in ``errors``.
    """
    parser = TrafficReportParser()
    validation_results = {
        "is_valid": False,
        "parsed_event": None,
        "errors": []
    }
    errors = validation_results["errors"]

    # (attribute name, message recorded when the attribute is falsy)
    required_fields = (
        ("road_section", "Missing road section information"),
        ("event_type", "Missing event type"),
        ("reason", "Missing reason for event"),
        ("consequence", "Missing consequence"),
    )

    try:
        # Try to parse the generated text
        event = parser.parse_report(predicted_text, "VALIDATION-001")

        if event:
            validation_results["parsed_event"] = event.dict()
            errors.extend(
                message
                for attr, message in required_fields
                if not getattr(event, attr)
            )
        else:
            errors.append("Failed to parse report")

    except Exception as e:
        errors.append(f"Parsing error: {str(e)}")

    # A report that parsed but lacks required fields (or failed mid-way) is
    # not valid output; previously is_valid stayed True alongside errors.
    validation_results["is_valid"] = (
        validation_results["parsed_event"] is not None and not errors
    )
    return validation_results

In [None]:
def _parse_report_or_none(parser, text, report_id):
    """Parse *text* with *parser*, returning None instead of raising."""
    # Model output is untrusted; one unparseable prediction must not abort a
    # whole evaluation run. validate_model_output() already records the
    # parser's error message for the predicted text.
    try:
        return parser.parse_report(text, report_id)
    except Exception:
        return None


def evaluate_model_output(input_text: str, predicted_text: str, expected_text: str) -> dict:
    """Comprehensive evaluation of one model output against its reference.

    Args:
        input_text: prompt given to the model (currently unused; kept for
            interface compatibility).
        predicted_text: text generated by the model.
        expected_text: reference text the model should have produced.

    Returns:
        dict with:
          - ``validation``: result of ``validate_model_output(predicted_text)``;
          - ``metrics``: currently only ``exact_match`` (bool);
          - ``rule_compliance``: per-rule booleans, filled only when both the
            prediction and the reference parse into events. Missing road
            sections, directions or priorities count as non-compliant.
    """
    results = {
        "validation": validate_model_output(predicted_text),
        "metrics": {},
        "rule_compliance": {}
    }

    # Basic text similarity metrics
    results["metrics"]["exact_match"] = predicted_text == expected_text

    # Rule compliance checks
    parser = TrafficReportParser()
    predicted_event = _parse_report_or_none(parser, predicted_text, "EVAL-001")
    expected_event = _parse_report_or_none(parser, expected_text, "EXPECTED-001")

    if not (predicted_event and expected_event):
        return results

    compliance = results["rule_compliance"]
    predicted_section = getattr(predicted_event, "road_section", None)
    expected_section = getattr(expected_event, "road_section", None)
    have_sections = predicted_section is not None and expected_section is not None

    # Check road type compliance
    compliance["road_type"] = have_sections and (
        predicted_section.road_type == expected_section.road_type
    )

    # Check event type hierarchy compliance: predicted priority must not
    # exceed the expected one (None on either side cannot be compared).
    predicted_priority = predicted_event.priority
    expected_priority = expected_event.priority
    compliance["event_priority"] = (
        predicted_priority is not None
        and expected_priority is not None
        and predicted_priority <= expected_priority
    )

    # Check direction terminology
    predicted_direction = predicted_section.direction if have_sections else None
    expected_direction = expected_section.direction if have_sections else None
    compliance["direction"] = (
        predicted_direction is not None
        and expected_direction is not None
        and predicted_direction.to_location == expected_direction.to_location
    )

    # Check emergency information timing
    if predicted_event.event_type in [EventType.WRONG_WAY_DRIVER,
                                      EventType.CLOSED_MOTORWAY,
                                      EventType.ACCIDENT_WITH_JAM]:
        # Placeholder: always True; should verify updates every 15-20 min.
        compliance["emergency_timing"] = True

    # Check traffic jam reporting rules (standard minimum length is 1.0)
    if predicted_event.jam_length:
        compliance["jam_length"] = predicted_event.jam_length >= 1.0

    return results

In [None]:
def create_test_suite():
    """Build a fresh list of evaluation cases.

    Each case is a dict with ``input`` (English incident note fed to the
    model), ``expected`` (reference Slovenian report) and ``description``.
    """
    # (input, expected, description) — append more tuples to add test cases.
    cases = (
        (
            "accident on A1 near Koper, traffic jam 2km",
            "Na primorski avtocesti proti Kopru je zaradi nesreče nastal zastoj dolg 2 kilometra med razcepom Kozarje in priključkom Koper.",
            "Standard accident report",
        ),
        (
            "wrong way driver on A2 near Maribor",
            "Opozarjamo vse voznike, ki vozijo po štajerski avtocesti proti Mariboru, da je na njihovo polovico avtoceste zašel voznik, ki vozi v napačno smer. Vozite skrajno desno in ne prehitevajte.",
            "Wrong way driver emergency",
        ),
    )
    return [
        {"input": source, "expected": reference, "description": label}
        for source, reference, label in cases
    ]

def run_evaluation(model, test_suite):
    """Generate a prediction for every test case and evaluate it.

    Args:
        model: callable used like a Hugging Face text-generation pipeline,
            i.e. ``model(text)[0]["generated_text"]`` yields a string.
        test_suite: iterable of dicts with ``input``, ``expected`` and
            ``description`` keys (as built by ``create_test_suite``).

    Returns:
        list of dicts with ``test_case``, ``input``, ``prediction`` and
        ``evaluation`` (the output of ``evaluate_model_output``).
    """
    results = []
    for test_case in test_suite:
        prompt = test_case["input"]

        # Generate prediction
        generated = model(prompt)[0]['generated_text']
        # Text-generation pipelines echo the prompt by default
        # (return_full_text=True); evaluate only the model's continuation.
        if generated.startswith(prompt):
            generated = generated[len(prompt):]
        prediction = generated.strip()

        # Evaluate
        evaluation = evaluate_model_output(
            prompt,
            prediction,
            test_case["expected"]
        )

        results.append({
            "test_case": test_case["description"],
            "input": prompt,
            "prediction": prediction,
            "evaluation": evaluation
        })

    return results

In [None]:
# Example usage: evaluate the pipeline loaded in the first cell (`model`);
# the previous `your_model` name was never defined and raised NameError.
test_suite = create_test_suite()
evaluation_results = run_evaluation(model, test_suite)

# Print results
for result in evaluation_results:
    print(f"\nTest Case: {result['test_case']}")
    print(f"Input: {result['input']}")
    print(f"Prediction: {result['prediction']}")
    print("Validation:", result['evaluation']['validation'])
    print("Rule Compliance:", result['evaluation']['rule_compliance'])

In [None]:
from enum import Enum
from datetime import datetime, time

class CommonEvents(str, Enum):
    """Labels for recurring (expected) congestion types.

    Mixing in ``str`` makes each member a string equal to its value.
    NOTE(review): not referenced elsewhere in this section; confirm whether
    other code uses it or it should be wired into TrafficReportValidator.
    """

    MORNING_RUSH_HOUR = "morning_rush_hour"  # common traffic jams in the morning
    EVENING_RUSH_HOUR = "evening_rush_hour"  # common traffic jams in the evening
    HOLIDAY_TRAFFIC = "holiday_traffic"  # common traffic jams during holidays

class TrafficReportValidator:
    """Decide whether a traffic-jam event should be broadcast.

    Recurring congestion (rush hours on known roads, holiday traffic) is only
    reported once the jam reaches a pattern-specific minimum length. A jam
    matching no common pattern is reported from ``_STANDARD_MIN_JAM_LENGTH``
    upward.

    NOTE(review): thresholds are compared directly with ``event.jam_length``;
    presumably both are kilometres — confirm against TrafficEvent.
    """

    # Minimum jam length for traffic that matches no common pattern.
    _STANDARD_MIN_JAM_LENGTH = 1.0

    def __init__(self):
        # ... (previous initialization code) ...

        # Common traffic patterns. Each pattern maps a road name (as found in
        # event.road_section.road_name) to its reporting thresholds:
        #   min_jam_length   - shorter jams on this road are not reported
        #   update_frequency - minutes between report updates
        self.common_traffic_patterns = {
            "MORNING_RUSH": {
                "time_range": (time(6, 0), time(9, 0)),  # 6:00 - 9:00
                "locations": {
                    "ŠTAJERSKA AVTOCESTA": {
                        "description": "Morning rush hour on Štajerska avtocesta",
                        "min_jam_length": 3.0,  # Minimum jam length to report
                        "update_frequency": 30  # Update frequency in minutes
                    }
                }
            },
            "EVENING_RUSH": {
                "time_range": (time(15, 0), time(18, 0)),  # 15:00 - 18:00
                "locations": {
                    "LJUBLJANSKA SEVERNA OBVOZNICA": {
                        "description": "Evening rush hour on northern ring road",
                        "min_jam_length": 3.0,
                        "update_frequency": 30
                    },
                    "LJUBLJANSKA JUŽNA OBVOZNICA": {
                        "description": "Evening rush hour on southern ring road",
                        "min_jam_length": 3.0,
                        "update_frequency": 30
                    }
                }
            },
            "HOLIDAY_TRAFFIC": {
                "time_range": None,  # Special handling for holidays
                "locations": {
                    "PRIMORSKA AVTOCESTA": {
                        "description": "Holiday traffic to/from coast",
                        "min_jam_length": 2.0,
                        "update_frequency": 15
                    },
                    "GORENJSKA AVTOCESTA": {
                        "description": "Holiday traffic to/from Austria",
                        "min_jam_length": 2.0,
                        "update_frequency": 15
                    }
                }
            }
        }

    @staticmethod
    def _road_name(event: "TrafficEvent"):
        """Return the event's road name, or None if it has no road section."""
        section = getattr(event, "road_section", None)
        return getattr(section, "road_name", None)

    def is_common_traffic(self, event: "TrafficEvent", at: "time | None" = None) -> bool:
        """Check if the event is a common traffic pattern with a reportable jam.

        Args:
            event: traffic event to classify.
            at: time of day to evaluate; defaults to the current local time.

        Returns:
            True if the event's road/time matches a common pattern AND its
            jam length reaches that pattern's ``min_jam_length``.
        """
        pattern = self._get_traffic_pattern(event, at)
        if pattern is None:
            return False
        jam_length = getattr(event, "jam_length", None)
        return bool(jam_length) and jam_length >= pattern["min_jam_length"]

    def _is_holiday(self) -> bool:
        """Check if current date is a holiday"""
        # This would need to be implemented with actual holiday dates
        # For now, just a placeholder
        return False

    def validate_traffic_jam_report(self, event: "TrafficEvent", at: "time | None" = None) -> dict:
        """Validate traffic jam reports considering common patterns.

        Args:
            event: traffic event to validate.
            at: time of day to evaluate; defaults to the current local time.

        Returns:
            dict with ``should_report`` (bool), ``reason`` (str) and
            ``update_frequency`` (minutes; set only for reportable common
            traffic, otherwise None).
        """
        validation = {
            "should_report": True,
            "reason": "",
            "update_frequency": None
        }

        # Match on road/time only and apply the length threshold here. Gating
        # on is_common_traffic() (which already requires jam >= minimum) made
        # the "insufficient" branch unreachable, so short rush-hour jams fell
        # through to the 1.0 standard rule. Resolving once also samples the
        # clock only once.
        pattern = self._get_traffic_pattern(event, at)
        jam_length = getattr(event, "jam_length", None) or 0.0

        if pattern is not None:
            if jam_length < pattern["min_jam_length"]:
                validation["should_report"] = False
                validation["reason"] = "Common traffic pattern with insufficient jam length"
            else:
                validation["update_frequency"] = pattern["update_frequency"]
                validation["reason"] = "Common traffic pattern with significant jam length"
        elif jam_length >= self._STANDARD_MIN_JAM_LENGTH:
            # For non-common traffic, use standard reporting rules
            validation["should_report"] = True
            validation["reason"] = "Non-common traffic with significant jam length"
        else:
            validation["should_report"] = False
            validation["reason"] = "Insufficient jam length for reporting"

        return validation

    def _get_traffic_pattern(self, event: "TrafficEvent", at: "time | None" = None) -> "dict | None":
        """Return the pattern config matching the event's road and time, else None.

        Rush-hour windows are inclusive at both ends; holiday patterns apply
        only when ``_is_holiday()`` is True. The first match wins (morning,
        then evening, then holiday).
        """
        moment = at if at is not None else datetime.now().time()
        road_name = self._road_name(event)
        if road_name is None:
            return None

        for key in ("MORNING_RUSH", "EVENING_RUSH"):
            pattern = self.common_traffic_patterns[key]
            start, end = pattern["time_range"]
            if start <= moment <= end and road_name in pattern["locations"]:
                return pattern["locations"][road_name]

        if self._is_holiday():
            holiday_locations = self.common_traffic_patterns["HOLIDAY_TRAFFIC"]["locations"]
            if road_name in holiday_locations:
                return holiday_locations[road_name]

        return None