In [None]:
import os
import gc
import logging
import random
import string
import re
from abc import ABC, abstractmethod

import networkx as nx
import numpy as np
import pandas as pd
import torch
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
from transformers import GPT2LMHeadModel, GPT2TokenizerFast
import itertools

# Run on the GPU whenever torch can see one.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Our own messages at INFO level; only errors from the chattier `transformers` logger.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logging.getLogger("transformers").setLevel(logging.ERROR)


def set_seed(seed=42):
    """Seed every RNG this notebook draws from.

    Covers Python's ``random`` module, NumPy's global RNG and torch's CPU generator,
    plus every CUDA device when CUDA is available.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


set_seed()
print(f"Setup complete. Using device: {DEVICE}")


In [None]:
class NavigationGPT:
    """Wrapper around a GPT-2 checkpoint that decodes a path plan from a prompt."""

    def __init__(self, model_path: str):
        """Load tokenizer and model from ``model_path``, move the model to ``DEVICE``, set eval mode."""
        self.tokenizer = GPT2TokenizerFast.from_pretrained(model_path)
        # Without a pad token, generate() has no pad id; reuse EOS for it.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        self.model = GPT2LMHeadModel.from_pretrained(model_path, torch_dtype=torch.float32)
        self.model.config.pad_token_id = self.tokenizer.pad_token_id
        self.model.to(DEVICE)
        self.model.eval()

    def generate(self, prompt: str, max_new_tokens: int = 60, do_sample: bool = False):
        """Return the model's continuation of ``prompt`` (prompt tokens removed), stripped.

        Args:
            prompt: full prompt text.
            max_new_tokens: cap on generated tokens.
            do_sample: passed explicitly (default False = greedy) so a checkpoint whose
                saved generation config enables sampling cannot silently make the
                evaluation stochastic.
        """
        encodings = self.tokenizer(prompt, return_tensors='pt').to(DEVICE)
        with torch.no_grad():
            output_ids = self.model.generate(
                **encodings,
                max_new_tokens=max_new_tokens,
                num_beams=1,
                do_sample=do_sample,
                # generate() reads the model's generation config, which may not reflect
                # the model.config edit made in __init__; pass the pad id directly.
                pad_token_id=self.tokenizer.pad_token_id,
            )
        new_token_ids = output_ids[0, encodings.input_ids.shape[1]:]
        return self.tokenizer.decode(new_token_ids, skip_special_tokens=True).strip()

def generate_grid_graph(size=5):
    """Build a ``size`` x ``size`` directed grid graph with unique random 2-letter node names.

    Each cell is linked to its 4-neighbours in both directions; every edge carries a
    ``direction`` attribute (EAST/WEST/SOUTH/NORTH, with row 0 being the northmost row).

    Args:
        size: side length of the grid.

    Returns:
        (G, nodes, node_map): the ``nx.DiGraph``, the flat row-major list of node names,
        and a ``size`` x ``size`` list-of-lists mapping (row, col) -> node name.

    Raises:
        ValueError: if the grid needs more than the 26 * 26 = 676 distinct 2-letter names.
    """
    num_nodes = size * size
    name_pool = [a + b for a, b in itertools.product(string.ascii_lowercase, repeat=2)]
    if num_nodes > len(name_pool):
        raise ValueError(
            f"grid_size={size} needs {num_nodes} unique node names, but only {len(name_pool)} 2-letter names exist"
        )
    # Sample without replacement from a fixed-order pool: names are guaranteed unique and
    # the assignment depends only on the `random` seed. (Iterating a set of strings would
    # depend on per-process hash randomization, breaking reproducibility across kernels.)
    nodes = random.sample(name_pool, num_nodes)
    node_map = [nodes[r * size:(r + 1) * size] for r in range(size)]

    G = nx.DiGraph()
    for r in range(size):
        for c in range(size):
            idx = r * size + c
            if c < size - 1: G.add_edge(nodes[idx], nodes[idx + 1], direction='EAST')
            if c > 0: G.add_edge(nodes[idx], nodes[idx - 1], direction='WEST')
            if r < size - 1: G.add_edge(nodes[idx], nodes[idx + size], direction='SOUTH')
            if r > 0: G.add_edge(nodes[idx], nodes[idx - size], direction='NORTH')
    return G, nodes, node_map

def generate_simple_random_walk(G, length):
    """Self-avoiding random walk of ``length`` nodes from a uniformly chosen start node.

    Each step moves to a uniformly chosen successor that has not been visited yet.
    Returns the list of visited nodes, or None if the walk gets stuck (every successor
    already visited) before it reaches ``length`` nodes.
    """
    current = random.choice(list(G.nodes()))
    walk = [current]
    seen = {current}
    while len(walk) < length:
        options = [nbr for nbr in G.successors(current) if nbr not in seen]
        if not options:
            return None
        current = random.choice(options)
        walk.append(current)
        seen.add(current)
    return walk

def walk_to_string(walk, G):
    """Serialize a walk as ``"n0 DIR0 n1 DIR1 ... nk"`` using each edge's ``direction``.

    An empty (or None) walk gives ``""``; a single-node walk gives that node unchanged.
    """
    if not walk:
        return ""
    if len(walk) < 2:
        return walk[0]
    pieces = []
    for here, there in zip(walk, walk[1:]):
        pieces.append(f"{here} {G.edges[here, there]['direction']}")
    pieces.append(walk[-1])
    return " ".join(pieces)

def get_node_coords(node_name, node_map):
    """Return ``(row, col)`` of the first occurrence of ``node_name`` in ``node_map``, else None."""
    return next(
        ((row_idx, row.index(node_name)) for row_idx, row in enumerate(node_map) if node_name in row),
        None,
    )

def get_manhattan_distance(n1, n2, node_map):
    """Grid (Manhattan) distance |d_row| + |d_col| between two nodes; ``inf`` if either is missing."""
    first = get_node_coords(n1, node_map)
    second = get_node_coords(n2, node_map)
    if not (first and second):
        return float('inf')
    (r1, c1), (r2, c2) = first, second
    return abs(r1 - r2) + abs(c1 - c2)

def _sample_endpoints(task_type, params, nodes, node_map):
    """Draw one (start, end) pair for ``task_type``; return None if the draw is rejected."""
    size = params['grid_size']
    if task_type == 'opposite_edge':
        # Opposite ends of a random row (first -> last column) or a random column
        # (first -> last row).
        first, last = 0, size - 1
        if random.random() > 0.5:
            r = random.randint(0, size - 1)
            return node_map[r][first], node_map[r][last]
        c = random.randint(0, size - 1)
        return node_map[first][c], node_map[last][c]
    # 'high_manhattan_distance': rejection-sample a pair that is far apart on the grid.
    s, e = random.sample(nodes, 2)
    if get_manhattan_distance(s, e, node_map) < params.get('min_md', 7):
        return None
    return s, e


def _find_covering_context(G, required_node_sets, walk_length, max_tries):
    """Return the first random walk visiting every node of some set in ``required_node_sets``, else None."""
    for _ in range(max_tries):
        walk = generate_simple_random_walk(G, walk_length)
        if walk is None:
            continue
        visited = set(walk)
        if any(req.issubset(visited) for req in required_node_sets):
            return walk
    return None


def create_test_cases(params):
    """Generate navigation tasks whose shortest path is fully covered by a random-walk context.

    Every case uses a freshly generated grid graph. A case is kept only if some
    self-avoiding random walk of ``context_walk_length`` nodes visits every node of at
    least one shortest start->end path, so all nodes the answer needs appear in the prompt.

    Args:
        params: dict with keys
            - ``num_tests`` (int): number of cases requested.
            - ``grid_size`` (int): side length of the grid.
            - ``context_walk_length`` (int): number of nodes in the context walk.
            - ``task_type`` (str): ``'opposite_edge'`` or ``'high_manhattan_distance'``.
            - ``min_md`` (int, optional, default 7): minimum Manhattan distance for
              ``'high_manhattan_distance'``.
            - ``max_attempts_per_test`` (int, optional, default 500): graph-generation
              budget per requested case.
            - ``context_tries`` (int, optional, default 250): random walks tried per graph.

    Returns:
        List of dicts with keys ``graph``, ``node_map``, ``start``, ``end``, ``context``.
        Possibly shorter than ``num_tests`` (a warning is logged); empty for an unknown
        ``task_type`` (an error is logged).
    """
    task_type = params.get("task_type")
    if task_type not in ('opposite_edge', 'high_manhattan_distance'):
        logging.error(f"Unknown task_type: {task_type}")
        return []

    num_tests = params['num_tests']
    max_total_attempts = num_tests * params.get('max_attempts_per_test', 500)
    context_tries = params.get('context_tries', 250)
    test_cases = []
    attempts = 0

    # The context manager closes the progress bar even on an exception.
    with tqdm(total=num_tests, desc=f"Generating tasks for {task_type}") as pbar:
        while len(test_cases) < num_tests and attempts < max_total_attempts:
            attempts += 1
            G, nodes, node_map = generate_grid_graph(params['grid_size'])

            endpoints = _sample_endpoints(task_type, params, nodes, node_map)
            if endpoints is None:
                continue
            start_node, end_node = endpoints

            try:
                all_solution_paths = list(nx.all_shortest_paths(G, start_node, end_node))
            except nx.NetworkXNoPath:
                continue
            if not all_solution_paths:
                continue
            list_of_required_nodes = [set(p) for p in all_solution_paths]

            context = _find_covering_context(
                G, list_of_required_nodes, params['context_walk_length'], context_tries
            )
            if context is None:
                continue

            test_cases.append({
                'graph': G,
                'node_map': node_map,
                'start': start_node,
                'end': end_node,
                'context': context
            })
            pbar.update(1)

    if len(test_cases) < num_tests:
        logging.warning(f"Warning: Only generated {len(test_cases)}/{params['num_tests']} cases for {task_type}.")
    return test_cases

def parse_path(text):
    """Extract node names (standalone two-lowercase-letter tokens) from ``text``, in order."""
    return [match.group(0) for match in re.finditer(r'\b[a-z]{2}\b', text)]
def is_valid_path(nodes, G):
    """Return True iff ``nodes`` is a non-empty walk along existing edges of ``G``.

    An empty node list is not a valid path, so an empty generation is not counted
    as valid. A single node is valid iff it is a node of ``G``.
    """
    if not nodes:
        return False
    if len(nodes) == 1:
        return nodes[0] in G
    return all(G.has_edge(u, v) for u, v in zip(nodes, nodes[1:]))

def score_and_analyze(parsed_nodes, task):
    """Score one generated path against the task and collect difficulty features.

    A prediction counts as correct only if it is a valid path in the task graph, starts at
    ``task['start']``, ends at ``task['end']`` and equals one of the shortest paths.

    Returns a dict with ``accuracy`` (1.0/0.0), ``is_valid``, ``expected_len`` (nodes in a
    shortest path, -1 if none), ``generated_len``, ``manhattan_distance`` and
    ``path_distance_in_context`` (index gap between start and end in the context walk,
    ``inf`` if either is absent).
    """
    graph = task['graph']
    start, goal = task['start'], task['end']

    try:
        optimal_paths = list(nx.all_shortest_paths(graph, start, goal))
        expected_len = len(optimal_paths[0])
    except (nx.NetworkXNoPath, nx.NodeNotFound):
        optimal_paths, expected_len = [], -1

    path_ok = is_valid_path(parsed_nodes, graph)
    hits_target = (
        path_ok
        and parsed_nodes
        and optimal_paths
        and parsed_nodes[0] == start
        and parsed_nodes[-1] == goal
        and parsed_nodes in optimal_paths
    )

    try:
        context_gap = abs(task['context'].index(goal) - task['context'].index(start))
    except (ValueError, IndexError):
        context_gap = float('inf')

    return {
        'accuracy': 1.0 if hits_target else 0.0,
        'is_valid': path_ok,
        'expected_len': expected_len,
        'generated_len': len(parsed_nodes),
        'manhattan_distance': get_manhattan_distance(start, goal, task['node_map']),
        'path_distance_in_context': context_gap
    }

class PromptStrategy(ABC):
    """Interface for turning a task dict into the prompt text fed to a model."""

    @abstractmethod
    def create_prompt(self, task: dict) -> str:
        """Build and return the prompt string for ``task``; concrete strategies must override."""

class StandardInstructionalStrategy(PromptStrategy):
    """Prompt layout: serialized context walk, then a shortest-path instruction, then ``[PLAN]``."""

    def create_prompt(self, task: dict) -> str:
        """Render ``task`` (keys ``context``, ``graph``, ``start``, ``end``) as a prompt string."""
        serialized_walk = walk_to_string(task['context'], task['graph'])
        goal_clause = f"[SHORTEST] [START_NODE] {task['start']} [GOAL] {task['end']}"
        return f"[SOS] {serialized_walk} [SEP] {goal_clause} [PLAN]"

def run_advanced_analysis(model_config: dict, test_cases: list, model_name: str = "model"):
    """Evaluate one model checkpoint on ``test_cases`` and return per-task metrics.

    Args:
        model_config: dict with ``path`` (checkpoint passed to ``NavigationGPT``) and
            ``strategy`` (object with ``create_prompt(task) -> str``).
        test_cases: task dicts as produced by ``create_test_cases``.
        model_name: label used in log messages and the progress bar.

    Returns:
        DataFrame with one row of ``score_and_analyze`` metrics per task; an empty
        DataFrame if there are no test cases or the model fails to load (logged).
    """
    if not test_cases:
        logging.warning(f"Skipping evaluation for {model_name} as no test cases were provided.")
        return pd.DataFrame()
    try:
        model = NavigationGPT(model_config['path'])
    except Exception as e:
        # Deliberately best-effort: one broken checkpoint should not stop the other models.
        logging.error(f"SKIPPING {model_name}: Failed to load model. Error: {e}")
        return pd.DataFrame()

    results = []
    try:
        strategy = model_config['strategy']
        for task in tqdm(test_cases, desc=f"Evaluating {model_name}"):
            prompt = strategy.create_prompt(task)
            generated_text = model.generate(prompt)
            results.append(score_and_analyze(parse_path(generated_text), task))
    finally:
        # Drop this frame's reference to the model and release cached CUDA memory even if
        # evaluation raises part-way through, so later models can still fit on the GPU.
        del model
        gc.collect()
        torch.cuda.empty_cache()
    return pd.DataFrame(results)
    
class ExperimentLogger:
    """Collects per-(model, analysis) evaluation results.

    Keeps two views: one summary row per ``add`` call holding the mean ``accuracy``,
    and the full per-task frame tagged with ``Model`` and ``Analysis`` columns.
    """

    def __init__(self):
        self._summary_results = []        # dicts with keys Model / Analysis / Accuracy
        self._detailed_results_list = []  # per-task DataFrames tagged with Model / Analysis

    def add(self, model_name: str, analysis_name: str, results_df: pd.DataFrame):
        """Record ``results_df`` for (model, analysis); empty frames are skipped with a warning."""
        if results_df.empty:
            logging.warning(f"Skipped logging for {model_name}/{analysis_name} due to empty results.")
            return

        accuracy = results_df['accuracy'].mean()
        summary_row = {"Model": model_name, "Analysis": analysis_name, "Accuracy": accuracy}
        self._summary_results.append(summary_row)
        logging.info(f"Logged: Model={model_name}, Analysis={analysis_name}, Accuracy={accuracy:.2%}")

        tagged = results_df.assign(Model=model_name, Analysis=analysis_name)
        self._detailed_results_list.append(tagged)

    def get_summary_df(self) -> pd.DataFrame:
        """Model x Analysis table of accuracies (duplicates averaged); empty if nothing logged."""
        if not self._summary_results:
            return pd.DataFrame()
        return pd.DataFrame(self._summary_results).pivot_table(
            index='Model', columns='Analysis', values='Accuracy'
        )

    def get_detailed_df(self, analysis_name: str = None) -> pd.DataFrame:
        """All logged per-task rows, optionally restricted to one analysis; empty if nothing logged."""
        if not self._detailed_results_list:
            return pd.DataFrame()
        everything = pd.concat(self._detailed_results_list, ignore_index=True)
        if not analysis_name:
            return everything
        return everything[everything['Analysis'] == analysis_name]

def plot_stratified_results(combined_df: pd.DataFrame, analysis_name: str):
    """Plot accuracy against two difficulty measures, one line per model.

    Left panel: accuracy vs. Manhattan distance between start and end.
    Right panel: accuracy vs. the index gap between start and end in the context walk;
    rows where that gap is ``inf`` (start or end missing from the context) are dropped.

    Args:
        combined_df: per-task rows with columns ``manhattan_distance``,
            ``path_distance_in_context``, ``accuracy`` and ``Model``.
        analysis_name: used in the figure title and log messages.
    """
    if combined_df.empty:
        logging.warning(f"Skipping plot for '{analysis_name}' due to empty data.")
        return

    fig, (ax_md, ax_ctx) = plt.subplots(1, 2, figsize=(16, 6), sharey=True)
    fig.suptitle(f'Performance Analysis: {analysis_name}', fontsize=16)

    sns.lineplot(data=combined_df, x='manhattan_distance', y='accuracy', hue='Model', errorbar='ci', ax=ax_md, marker='o')
    ax_md.set_title('Accuracy vs. Path Difficulty (Manhattan Distance)')
    ax_md.set_xlabel('Manhattan Distance between Start/End')
    ax_md.set_ylabel('Accuracy')
    ax_md.set_ylim(-0.05, 1.05)
    ax_md.grid(True, linestyle='--', alpha=0.6)

    df_filtered = combined_df[combined_df['path_distance_in_context'] != float('inf')]
    if not df_filtered.empty:
        sns.lineplot(data=df_filtered, x='path_distance_in_context', y='accuracy', hue='Model', errorbar='ci', ax=ax_ctx, marker='o')
        ax_ctx.set_title('Accuracy vs. Contextual Proximity')
        ax_ctx.set_xlabel('Start-End Distance in Context Path')
    else:
        ax_ctx.set_title('No Valid In-Context Paths Found')

    ax_ctx.set_ylabel('')
    ax_ctx.grid(True, linestyle='--', alpha=0.6)
    # Both panels share the hue mapping, so the left legend suffices. Remove the right
    # one only if it exists: calling ax.legend() just to hide it builds a new legend and
    # warns when no labeled artists are found (e.g. the empty-panel branch).
    right_legend = ax_ctx.get_legend()
    if right_legend is not None:
        right_legend.remove()

    plt.tight_layout(rect=[0, 0.03, 1, 0.95])
    plt.show()

print("Core functions and classes loaded.")


In [None]:
# Directory holding the fine-tuned checkpoints. Set MSC_THESIS_ROOT to run elsewhere;
# the default is the original cluster location.
MODEL_ROOT = os.environ.get(
    "MSC_THESIS_ROOT", "/cs/student/projects1/aibh/2024/cbaumgar/MSC_THESIS"
)

MODELS_TO_ANALYZE = {
    "SP-RW": {
        "path": os.path.join(MODEL_ROOT, "sv3_model_ft", "checkpoint-163000"),
        "strategy": StandardInstructionalStrategy()
    },
    "SP-Hamiltonian": {
        "path": os.path.join(MODEL_ROOT, "sv2_model_fixed", "save-checkpoint-46000"),
        "strategy": StandardInstructionalStrategy()
    },
}

# Settings shared by every analysis; each entry below adds only what differs.
_COMMON_ANALYSIS_PARAMS = {
    "num_tests": 1000,
    "grid_size": 5,
    "context_walk_length": 16,
}

ANALYSIS_PARAMS = {
    "Generalization_5x5_OppositeEdge": {
        **_COMMON_ANALYSIS_PARAMS,
        "task_type": "opposite_edge",
    },
    "Generalization_5x5_HighMD": {
        **_COMMON_ANALYSIS_PARAMS,
        "task_type": "high_manhattan_distance",
        "min_md": 7
    }
}

logger = ExperimentLogger()
print("Configuration loaded.")


In [None]:
# Start each run of this cell from a clean slate: a fresh logger (otherwise re-running
# appends duplicate rows, which skews the pivoted means and the plots) and re-seeded
# RNGs (so each run starts from the same random state).
logger = ExperimentLogger()
set_seed()

for analysis_name, params in ANALYSIS_PARAMS.items():
    print(f"\n{'='*80}\nGenerating shared test cases for '{analysis_name}'\n{'='*80}")
    shared_test_cases = create_test_cases(params)

    if not shared_test_cases:
        logging.warning(f"Skipping analysis '{analysis_name}' as no test cases could be generated.")
        continue

    # Every model sees the identical task set, so accuracies are directly comparable.
    for model_name, config in MODELS_TO_ANALYZE.items():
        print(f"\n--- Running model: {model_name.upper()} on {analysis_name} ---")

        results_df = run_advanced_analysis(config, shared_test_cases, model_name)
        logger.add(model_name, analysis_name, results_df)

    combined_df_for_plotting = logger.get_detailed_df(analysis_name=analysis_name)
    if not combined_df_for_plotting.empty:
        plot_stratified_results(combined_df_for_plotting, analysis_name)

print("\n\nAll analyses complete.")


In [None]:
print("\n\n" + "="*80 + "\n" + " " * 20 + "FINAL GENERALIZATION PERFORMANCE SUMMARY" + "\n" + "="*80)

# Model x Analysis accuracy table built by the ExperimentLogger.
summary_df = logger.get_summary_df()

if not summary_df.empty:
    print("\n--- Summary Table ---")
    # Percent formatting, '-' for missing (model, analysis) cells, colour scale fixed to [0, 1].
    styled_table = (
        summary_df.style
        .format("{:.2%}", na_rep="-")
        .background_gradient(cmap='viridis', vmin=0, vmax=1)
        .set_caption("Model Generalization Performance Summary")
    )

    from IPython.display import display
    display(styled_table)
else:
    print("No analysis results to summarize.")
