# Training Log Interpreter

This notebook parses and visualizes training logs from the Int2Int model training.
It extracts model parameters, task information, and plots metrics over epochs.

In [None]:
import json
import re
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
from pathlib import Path

## Configuration

Specify the path to your training log file:

In [None]:
# Path of the training log to analyse. It is handed to parse_train_log()
# in the parsing cell; edit this value to point at a different run.
log_file_path = "models/basic/1/train.log"

## Parse Training Log

In [None]:
def parse_train_log(log_path):
    """
    Parse a training log file to extract parameters and metrics.

    Two kinds of lines are recognised:

    * Parameter lines. The parameter section opens on the line after one
      containing 'Initialized logger' and closes at the first line containing
      'Running command' or 'Starting epoch'. Inside it, every line of the form
      "<whitespace>key: value" is stored as a string pair. A key seen more
      than once keeps its last value.
    * Metric lines. Any line containing '__log__:' is decoded as JSON from the
      text that follows the marker. Lines that are not valid JSON, or whose
      payload is not a JSON object, are reported with a warning and skipped.

    The file is decoded as UTF-8; undecodable bytes are replaced instead of
    aborting the parse.

    Args:
        log_path: path of the log file to read.

    Returns:
        params: dict of training parameters (values are unconverted strings)
        metrics: list of dicts containing epoch metrics, in file order
    """
    params = {}
    metrics = []

    # Compiled once here rather than being looked up for every line.
    # NOTE(review): the pattern requires leading whitespace, so a parameter
    # that shares its line with a logger prefix is not captured -- confirm
    # whether the first logged parameter is affected.
    param_pattern = re.compile(r'^\s+(\w+):\s+(.+)$')
    log_marker = '__log__:'
    in_params_section = False

    # Explicit encoding keeps the result independent of the platform default;
    # errors='replace' means one stray byte cannot abort the whole parse.
    with open(log_path, 'r', encoding='utf-8', errors='replace') as f:
        # Stream line by line instead of loading the entire log into memory.
        for raw_line in f:
            line = raw_line.rstrip('\n')

            # Detect start of parameters section
            if 'Initialized logger' in line:
                in_params_section = True
                continue

            # Detect end of parameters section (checked before parsing so the
            # terminating line itself is never read as a parameter)
            if in_params_section and ('Running command' in line or 'Starting epoch' in line):
                in_params_section = False

            # Parse parameter lines
            if in_params_section:
                match = param_pattern.search(line)
                if match:
                    params[match.group(1)] = match.group(2).strip()

            # Everything below only concerns metric lines
            if log_marker not in line:
                continue

            # Extract JSON part after __log__:
            json_str = line.split(log_marker, 1)[1]
            try:
                metric_data = json.loads(json_str)
            except json.JSONDecodeError as e:
                print(f"Warning: Could not parse line: {line[:100]}...")
                print(f"Error: {e}")
                continue

            # Only JSON objects are usable as metric records; anything else
            # (list, number, string, null) would break the documented
            # "list of dicts" contract.
            if isinstance(metric_data, dict):
                metrics.append(metric_data)
            else:
                print(f"Warning: Ignoring non-object metric payload: {line[:100]}...")

    return params, metrics


# Parse the log file configured in log_file_path. `params` and `metrics`
# are module-level names that the reporting and plotting cells read.
params, metrics = parse_train_log(log_file_path)

# Quick sanity check of how much was extracted from the log.
print(f"Parsed {len(params)} parameters and {len(metrics)} epoch logs")

## Report Model Configuration and Training Parameters

In [None]:
def determine_task_type(params):
    """
    Classify the run as 'musq', 'mu' or 'unknown' from its data paths.

    The 'train_data' and 'eval_data' entries of ``params`` (each defaulting
    to an empty string) are searched case-insensitively. The match is a plain
    substring test, so any path containing "mu" anywhere counts as 'mu'.
    """
    # 'musq' has to be tried first: every path containing 'musq' also
    # contains 'mu'. Within one tag, train_data is consulted before eval_data.
    for tag in ('musq', 'mu'):
        for key in ('train_data', 'eval_data'):
            if tag in params.get(key, '').lower():
                return tag
    return 'unknown'


def print_model_config(params):
    """
    Print key model configuration and training parameters.

    The report is grouped into five sections (experiment info, data, model
    architecture, training parameters, other settings). Each value is looked
    up in ``params``; a missing key is shown as 'N/A'. The task type comes
    from determine_task_type(params) and is printed in upper case.

    Args:
        params: dict of training parameters.

    Returns:
        The task type string produced by determine_task_type(params).
    """
    task_type = determine_task_type(params)

    def lookup(key):
        # Every parameter falls back to the same placeholder when absent.
        return params.get(key, 'N/A')

    # (section heading, [(row label, value), ...]) in display order.
    # The heading glyphs are written as real Unicode characters.
    sections = [
        ("📋 EXPERIMENT INFO", [
            ("Experiment Name", lookup('exp_name')),
            ("Experiment ID", lookup('exp_id')),
            ("Task Type", task_type.upper()),
            ("Operation", lookup('operation')),
        ]),
        ("📊 DATA", [
            ("Training Data", lookup('train_data')),
            ("Eval Data", lookup('eval_data')),
            ("Data Types", lookup('data_types')),
            ("Base", lookup('base')),
            ("Modulus", lookup('modulus')),
        ]),
        ("🏗️ MODEL ARCHITECTURE", [
            ("Architecture", lookup('architecture')),
            ("Encoder Layers", lookup('n_enc_layers')),
            ("Decoder Layers", lookup('n_dec_layers')),
            ("Encoder Embedding Dim", lookup('enc_emb_dim')),
            ("Decoder Embedding Dim", lookup('dec_emb_dim')),
            ("Encoder Heads", lookup('n_enc_heads')),
            ("Decoder Heads", lookup('n_dec_heads')),
            ("Dropout", lookup('dropout')),
            ("Attention Dropout", lookup('attention_dropout')),
        ]),
        ("⚙️ TRAINING PARAMETERS", [
            ("Optimizer", lookup('optimizer')),
            ("Batch Size", lookup('batch_size')),
            ("Eval Batch Size", lookup('batch_size_eval')),
            ("Epoch Size", lookup('epoch_size')),
            ("Max Epochs", lookup('max_epoch')),
            ("Eval Size", lookup('eval_size')),
            ("Gradient Clipping", lookup('clip_grad_norm')),
            ("Max Length", lookup('max_len')),
            ("Max Output Length", lookup('max_output_len')),
        ]),
        ("🔧 OTHER SETTINGS", [
            ("FP16", lookup('fp16')),
            ("CPU Mode", lookup('cpu')),
            ("Multi-GPU", lookup('multi_gpu')),
            ("Num Workers", lookup('num_workers')),
        ]),
    ]

    print("=" * 80)
    print("MODEL CONFIGURATION AND TRAINING PARAMETERS")
    print("=" * 80)

    for heading, rows in sections:
        # A blank line separates each section from the previous one.
        print(f"\n{heading}")
        for label, value in rows:
            print(f"  {label}: {value}")

    print("\n" + "=" * 80)

    return task_type


# Print the configuration report and keep the detected task type; the
# plotting cells read `task_type` for their titles.
task_type = print_model_config(params)

## Convert Metrics to DataFrame

In [None]:
# Convert the parsed metric records into a DataFrame for easier analysis.
# `df_metrics` is the table every plotting/reporting cell below reads.
df_metrics = pd.DataFrame(metrics)

# Report the row count, then preview up to the first 10 rows. The bare
# expression must stay last so the notebook renders it as the cell output.
print(f"Total epochs recorded: {len(df_metrics)}")
print("\nFirst few epochs:")
df_metrics.head(10)

## Plot 1: Cross-Entropy Loss

In [None]:
# Validation cross-entropy loss per epoch, with the lowest value marked by a
# dashed horizontal line.
plt.figure(figsize=(12, 6))

if 'valid_arithmetic_xe_loss' in df_metrics.columns:
    xe_loss = df_metrics['valid_arithmetic_xe_loss']
    plt.plot(df_metrics['epoch'], xe_loss,
             marker='o', linewidth=2, markersize=4, label='XE Loss')

    plt.xlabel('Epoch', fontsize=12)
    plt.ylabel('Cross-Entropy Loss', fontsize=12)
    plt.title(f'Validation Cross-Entropy Loss Over Epochs\n({task_type.upper()} task)', fontsize=14, fontweight='bold')
    plt.grid(True, alpha=0.3)

    # Add min value annotation -- only when at least one loss value exists,
    # because an all-NaN column has no minimum row to look the epoch up in.
    if xe_loss.notna().any():
        min_loss = xe_loss.min()
        min_epoch = df_metrics.loc[xe_loss.idxmin(), 'epoch']
        plt.axhline(y=min_loss, color='r', linestyle='--', alpha=0.5, label=f'Min: {min_loss:.4f} (epoch {min_epoch})')

    # One legend call, issued after every labelled artist has been added.
    plt.legend(fontsize=10)
else:
    # Column absent: leave an explanatory note instead of an empty plot.
    plt.text(0.5, 0.5, 'valid_arithmetic_xe_loss not found in metrics',
             ha='center', va='center', fontsize=12)

plt.tight_layout()
plt.show()

## Plot 2: Accuracy Metrics (acc, perfect, correct)

In [None]:
# Overall validation accuracy curves (acc / perfect / correct) per epoch.
plt.figure(figsize=(12, 6))

acc_metrics = ['valid_arithmetic_acc', 'valid_arithmetic_perfect', 'valid_arithmetic_correct']
colors = ['blue', 'green', 'orange']
markers = ['o', 's', '^']

# Track whether anything was drawn so the empty case can be handled the same
# way the loss and class-specific plots handle it.
has_any_metric = False
for metric, color, marker in zip(acc_metrics, colors, markers):
    if metric in df_metrics.columns:
        has_any_metric = True
        plt.plot(df_metrics['epoch'], df_metrics[metric],
                marker=marker, linewidth=2, markersize=4,
                color=color, label=metric.replace('valid_arithmetic_', ''))

if has_any_metric:
    plt.xlabel('Epoch', fontsize=12)
    plt.ylabel('Accuracy (%)', fontsize=12)
    plt.title(f'Validation Accuracy Metrics Over Epochs\n({task_type.upper()} task)', fontsize=14, fontweight='bold')
    plt.grid(True, alpha=0.3)
    plt.legend(fontsize=10)
    plt.ylim([0, 105])  # Set y-axis range for percentages
else:
    # No labelled artists exist, so calling legend() would only warn;
    # show a note instead of an empty, misleadingly labelled plot.
    plt.text(0.5, 0.5, 'No accuracy metrics found',
             ha='center', va='center', fontsize=12)

plt.tight_layout()
plt.show()

## Plot 3: Class-specific Accuracies (acc_0, acc_1, acc_100)

In [None]:
# Per-class validation accuracy curves (acc_0 / acc_1 / acc_100) per epoch.
plt.figure(figsize=(12, 6))

class_metrics = ['valid_arithmetic_acc_0', 'valid_arithmetic_acc_1', 'valid_arithmetic_acc_100']
colors = ['purple', 'red', 'cyan']
markers = ['D', 'v', 'p']
# Legend text uses the real Greek letter mu.
labels = ['Class 0 (μ=0)', 'Class 1 (μ=1)', 'Class 100 (μ=-1)']

# Only columns that actually exist are drawn; remember whether any did.
has_any_metric = False
for metric, color, marker, label in zip(class_metrics, colors, markers, labels):
    if metric in df_metrics.columns:
        has_any_metric = True
        plt.plot(df_metrics['epoch'], df_metrics[metric],
                marker=marker, linewidth=2, markersize=4,
                color=color, label=label)

if has_any_metric:
    plt.xlabel('Epoch', fontsize=12)
    plt.ylabel('Accuracy (%)', fontsize=12)

    # The 'mu' task gets a second title line spelling out the class mapping.
    if task_type == 'mu':
        plt.title('Class-Specific Accuracies Over Epochs (MU task)\nClass 0: μ=0, Class 1: μ=1, Class 100: μ=-1',
                 fontsize=14, fontweight='bold')
    else:
        plt.title(f'Class-Specific Accuracies Over Epochs ({task_type.upper()} task)',
                 fontsize=14, fontweight='bold')

    plt.grid(True, alpha=0.3)
    plt.legend(fontsize=10)
    plt.ylim([0, 105])  # Set y-axis range for percentages
else:
    plt.text(0.5, 0.5, 'No class-specific accuracy metrics found',
             ha='center', va='center', fontsize=12)

plt.tight_layout()
plt.show()

## Plot 4: Digit-wise Accuracies (if available)

In [None]:
# Collect every column whose name contains 'acc_d' (digit-wise accuracies)
digit_metrics = [name for name in df_metrics.columns if 'acc_d' in name]

if not digit_metrics:
    # Nothing to plot: report it in text and do not open a figure
    print("No digit-wise accuracy metrics found in the log.")
else:
    plt.figure(figsize=(12, 6))

    # One curve per digit-wise column, labelled without the shared prefix
    epochs = df_metrics['epoch']
    for metric in digit_metrics:
        series_label = metric.replace('valid_arithmetic_', '')
        plt.plot(epochs, df_metrics[metric],
                 marker='o', linewidth=2, markersize=4,
                 label=series_label)

    plt.xlabel('Epoch', fontsize=12)
    plt.ylabel('Accuracy (%)', fontsize=12)
    plt.title(f'Digit-wise Accuracies Over Epochs\n({task_type.upper()} task)',
              fontsize=14, fontweight='bold')
    plt.grid(True, alpha=0.3)
    plt.legend(fontsize=10)
    plt.ylim([0, 105])

    plt.tight_layout()
    plt.show()

## Best Metrics Report

In [None]:
def _best_value_and_epoch(df_metrics, column, lowest=False):
    """Return (best value, epoch) for ``column``, or None if it has no values."""
    series = df_metrics[column]
    # An all-NaN column has no best row; looking one up would fail.
    if not series.notna().any():
        return None
    best_value = series.min() if lowest else series.max()
    best_index = series.idxmin() if lowest else series.idxmax()
    # Without an 'epoch' column the value is still reported, just unlocated.
    epoch = df_metrics.loc[best_index, 'epoch'] if 'epoch' in df_metrics.columns else 'N/A'
    return best_value, epoch


def _print_best_accuracy(df_metrics, column, title):
    """Print the highest value of an accuracy ``column`` under ``title``."""
    print(f"\n  {title}:")
    best = _best_value_and_epoch(df_metrics, column)
    if best is None:
        print("    Best: N/A")
        return
    best_value, epoch = best
    print(f"    Best: {best_value:.2f}%")
    print(f"    Achieved at Epoch: {epoch}")


def report_best_metrics(df_metrics):
    """
    Report the best (minimum for loss, maximum for accuracies) metrics across all epochs.

    Prints, in order: the best cross-entropy loss, the best overall accuracy
    metrics, the best class-specific accuracies, the best digit-wise
    accuracies (columns whose name contains 'acc_d'), summary statistics, and
    the values in the last row. Sections for absent columns are skipped. A
    column with no non-NaN value is reported as 'N/A', and a missing 'epoch'
    column makes the epoch show as 'N/A' instead of raising.

    Args:
        df_metrics: DataFrame with one row per logged metric record.

    Returns:
        None. The report is written to standard output.
    """
    columns = df_metrics.columns

    print("=" * 80)
    print("BEST METRICS ACROSS ALL EPOCHS")
    print("=" * 80)

    # Cross-Entropy Loss (lower is better)
    if 'valid_arithmetic_xe_loss' in columns:
        print("\n📉 CROSS-ENTROPY LOSS (lower is better)")
        best_loss = _best_value_and_epoch(df_metrics, 'valid_arithmetic_xe_loss', lowest=True)
        if best_loss is None:
            print("  Best Loss: N/A")
        else:
            min_loss, min_loss_epoch = best_loss
            print(f"  Best Loss: {min_loss:.6f}")
            print(f"  Achieved at Epoch: {min_loss_epoch}")

    # Accuracy metrics (higher is better)
    print("\n📈 ACCURACY METRICS (higher is better)")
    acc_metrics = ['valid_arithmetic_acc', 'valid_arithmetic_perfect', 'valid_arithmetic_correct']
    for metric in acc_metrics:
        if metric in columns:
            _print_best_accuracy(df_metrics, metric, metric.replace('valid_arithmetic_', '').upper())

    # Class-specific accuracies
    class_metrics = ['valid_arithmetic_acc_0', 'valid_arithmetic_acc_1', 'valid_arithmetic_acc_100']
    class_labels = ['Class 0 (μ=0)', 'Class 1 (μ=1)', 'Class 100 (μ=-1)']
    present_classes = [(metric, label) for metric, label in zip(class_metrics, class_labels)
                       if metric in columns]
    if present_classes:
        print("\n📊 CLASS-SPECIFIC ACCURACIES (higher is better)")
        for metric, label in present_classes:
            _print_best_accuracy(df_metrics, metric, label)

    # Digit-wise accuracies
    digit_metrics = [col for col in columns if 'acc_d' in col]
    if digit_metrics:
        print("\n🔢 DIGIT-WISE ACCURACIES (higher is better)")
        for metric in sorted(digit_metrics):
            _print_best_accuracy(df_metrics, metric, metric.replace('valid_arithmetic_', '').upper())

    # Summary statistics
    print("\n" + "=" * 80)
    print("SUMMARY STATISTICS")
    print("=" * 80)
    print(f"Total Epochs Recorded: {len(df_metrics)}")
    if 'epoch' in columns:
        print(f"Epoch Range: {df_metrics['epoch'].min()} - {df_metrics['epoch'].max()}")

    # Latest epoch metrics. Values are read column by column: pulling the
    # whole last row out as one Series would coerce every cell to a common
    # dtype and print an integer epoch as a float.
    if len(df_metrics) > 0:
        print("\n📍 LATEST EPOCH METRICS:")
        if 'epoch' in columns:
            print(f"  Epoch: {df_metrics['epoch'].iloc[-1]}")
        latest_rows = (
            ('valid_arithmetic_xe_loss', "  XE Loss: {:.6f}"),
            ('valid_arithmetic_acc', "  Accuracy: {:.2f}%"),
            ('valid_arithmetic_perfect', "  Perfect: {:.2f}%"),
        )
        for column, template in latest_rows:
            if column in columns:
                print(template.format(df_metrics[column].iloc[-1]))

    print("\n" + "=" * 80)


# Print the best-metrics report for the metrics table built above.
report_best_metrics(df_metrics)

## Export Metrics to CSV (Optional)

In [None]:
# Uncomment to export metrics to CSV
# output_csv = log_file_path.replace('.log', '_metrics.csv')
# df_metrics.to_csv(output_csv, index=False)
# print(f"Metrics exported to: {output_csv}")

## Detailed Metrics Table

In [None]:
# Display all metrics for inspection: first the list of column names ...
print("\nAll available metric columns:")
print(df_metrics.columns.tolist())

# ... then the whole table. The bare expression must stay last so the
# notebook renders it as the cell output.
print("\nFull metrics table:")
df_metrics