In [None]:
!pip install seaborn

import json
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from collections import Counter, defaultdict
import numpy as np
import os

def ensure_plot_dir(path='plot'):
    """Ensure the output directory for plots exists.

    ``exist_ok=True`` makes the call idempotent. It also removes the race
    between an ``os.path.exists`` check and the ``makedirs`` call, which
    could fail if the directory appeared in between.

    Args:
        path: Directory to create. Defaults to ``'plot'``, the directory
            every plotting function in this notebook writes into.
    """
    os.makedirs(path, exist_ok=True)

def load_combinations(file_path):
    """Load one iteration's results file and return the parsed JSON.

    The file is decoded explicitly as UTF-8 (the JSON interchange encoding)
    rather than with the platform's default locale encoding, which differs
    on some systems.

    Args:
        file_path: Path to a JSON file. Callers in this notebook treat the
            result as a mapping of model id -> list of package/version dicts.

    Returns:
        The deserialized JSON object.
    """
    with open(file_path, 'r', encoding='utf-8') as f:
        return json.load(f)

def get_successful_models_count(combinations):
    """Return how many models have a non-empty (truthy) entry.

    Args:
        combinations: Mapping of model id -> list of successful combinations;
            empty lists / ``None`` mean the model had no success.

    Returns:
        int: number of models with at least one recorded combination.
    """
    successful_ids = [model_id for model_id, entries in combinations.items() if entries]
    return len(successful_ids)

def get_package_frequency(combinations):
    """Count how many combinations each package name appears in.

    Args:
        combinations: Mapping of model id -> list of ``{package: version}``
            dicts; falsy entries are skipped.

    Returns:
        ``defaultdict(int)`` of package name -> occurrence count (missing
        packages read as 0).
    """
    frequency = defaultdict(int)
    non_empty = (entries for entries in combinations.values() if entries)
    for entries in non_empty:
        for combo in entries:
            for name in combo:
                frequency[name] += 1
    return frequency

def get_version_distribution(combinations):
    """Tally, per package, how many combinations used each version.

    Args:
        combinations: Mapping of model id -> list of ``{package: version}``
            dicts; falsy entries are skipped.

    Returns:
        Nested ``defaultdict``: package -> version -> count, where unseen
        keys read as an empty mapping / 0.
    """
    distribution = defaultdict(lambda: defaultdict(int))
    for entries in filter(None, combinations.values()):
        for combo in entries:
            for name, ver in combo.items():
                distribution[name][ver] += 1
    return distribution

def create_iteration_progress_plot(iteration_files):
    """Plot the number of successful models in each iteration file.

    Args:
        iteration_files: Ordered sequence of JSON result paths; the k-th path
            (1-based) is plotted as iteration k.

    Writes ``plot/iteration_progress.png``; the ``plot`` directory must
    already exist (``main`` calls ``ensure_plot_dir`` first).
    """
    counts = [
        get_successful_models_count(load_combinations(path))
        for path in iteration_files
    ]
    iterations = list(range(1, len(counts) + 1))

    plt.figure(figsize=(10, 6))
    plt.plot(iterations, counts, marker='o')
    # Iterations are discrete: pin ticks to whole numbers so the axis is not
    # labelled with fractional values such as 1.5.
    plt.xticks(iterations)
    plt.title('Number of Successful Models Across Iterations')
    plt.xlabel('Iteration')
    plt.ylabel('Number of Successful Models')
    plt.grid(True)
    plt.savefig('plot/iteration_progress.png')
    plt.close()

def create_package_heatmap(iteration_files):
    """Save a heatmap of how often packages appear together.

    Cell (i, j) is the fraction of combinations containing package i that
    also contain package j. It is row i of the raw co-occurrence counts
    divided by its diagonal entry, the number of combinations containing
    package i, so the diagonal is 1.

    Args:
        iteration_files: Paths of iteration JSON files; every non-empty
            combination in every file contributes.

    Writes ``plot/package_heatmap.png``. Prints a message and writes nothing
    when no packages are found (an empty matrix cannot be drawn).
    """
    # Parse each file once and keep every non-empty combination.
    combos = []
    for file in iteration_files:
        for model_data in load_combinations(file).values():
            if model_data:
                combos.extend(model_data)

    # Sorted for a stable axis order: iteration order of a set of strings
    # changes between interpreter runs because of hash randomization.
    package_list = sorted({package for combo in combos for package in combo})
    if not package_list:
        print("No package combinations found; skipping package heatmap.")
        return
    n_packages = len(package_list)
    position = {package: i for i, package in enumerate(package_list)}

    # Raw counts: cooccurrence[i, j] = number of combinations containing both
    # package i and package j (the outer product of a 0/1 presence vector
    # marks exactly those pairs).
    cooccurrence = np.zeros((n_packages, n_packages))
    for combo in combos:
        present = np.zeros(n_packages)
        for package in combo:
            present[position[package]] = 1
        cooccurrence += np.outer(present, present)

    # Normalize each row by its diagonal value; rows with a zero diagonal
    # stay 0 instead of dividing by zero.
    diagonal = np.diag(cooccurrence)[:, np.newaxis]
    normalized_cooccurrence = np.divide(
        cooccurrence,
        diagonal,
        out=np.zeros_like(cooccurrence),
        where=diagonal > 0,
    )

    plt.figure(figsize=(15, 12))
    sns.heatmap(normalized_cooccurrence,
                xticklabels=package_list,
                yticklabels=package_list,
                cmap='YlOrRd',
                annot=True,
                fmt='.2f',  # Show 2 decimal places
                vmin=0,
                vmax=1)  # Values are fractions, so fix the colour range to [0, 1]
    plt.title('Normalized Package Co-occurrence Heatmap')
    plt.xticks(rotation=45, ha='right')
    plt.yticks(rotation=0)
    plt.tight_layout()
    plt.savefig('plot/package_heatmap.png')
    plt.close()

def create_version_distribution_plot(iteration_files, major_packages=None):
    """Save bar charts of how often each version of key packages occurs.

    Combinations from all iteration files are merged per model. For each
    package in *major_packages*, one subplot shows the count of every version
    seen, with versions in natural order ('4.9.0' before '4.10.0'). A package
    that never occurs leaves its subplot slot empty.

    Args:
        iteration_files: Paths of iteration JSON files to merge.
        major_packages: Package names to plot, one subplot each, laid out
            three per row. Defaults to transformers, torch, numpy, datasets
            and tokenizers.

    Writes ``plot/version_distribution.png``.
    """
    import math
    import re

    if major_packages is None:
        major_packages = ['transformers', 'torch', 'numpy', 'datasets', 'tokenizers']

    def version_sort_key(version):
        # re.split with a capturing group alternates non-digit / digit runs,
        # so odd positions are always digit runs: compare those as ints.
        # Each position holds the same type for every version, so comparing
        # two keys never raises TypeError.
        parts = re.split(r'(\d+)', str(version))
        return [int(part) if i % 2 else part for i, part in enumerate(parts)]

    # Merge all iterations per model.
    all_combinations = {}
    for file in iteration_files:
        for model_id, model_data in load_combinations(file).items():
            if model_data:
                all_combinations.setdefault(model_id, []).extend(model_data)

    version_dist = get_version_distribution(all_combinations)

    n_cols = 3
    n_rows = max(1, math.ceil(len(major_packages) / n_cols))
    plt.figure(figsize=(15, 4 * n_rows))  # (15, 8) for the default 5 packages
    for i, package in enumerate(major_packages, 1):
        if package not in version_dist:
            continue
        versions = sorted(version_dist[package], key=version_sort_key)
        counts = [version_dist[package][version] for version in versions]
        plt.subplot(n_rows, n_cols, i)
        # str() so non-string versions still render as categorical labels.
        plt.bar([str(version) for version in versions], counts)
        plt.title(f'{package} Version Distribution')
        plt.xticks(rotation=45, ha='right')
        plt.ylabel('Count')

    plt.tight_layout()
    plt.savefig('plot/version_distribution.png')
    plt.close()

def analyze_version_conflicts(iteration_files):
    """Write a report of packages observed with more than one version.

    Combinations from all iteration files are merged per model. A package is
    reported when two or more distinct versions of it appear anywhere.
    Packages are listed alphabetically. Each package's versions are listed
    most-frequent first, with ties kept in first-seen order, so the report
    is stable across runs.

    Args:
        iteration_files: Paths of iteration JSON files to merge.

    Returns:
        dict mapping package -> {version: count} for the reported packages
        (previously ``None``; existing callers ignore the return value).

    Writes ``plot/version_conflicts.txt`` (UTF-8).
    """
    all_combinations = {}
    for file in iteration_files:
        for model_id, model_data in load_combinations(file).items():
            if model_data:
                all_combinations.setdefault(model_id, []).extend(model_data)

    version_dist = get_version_distribution(all_combinations)

    # Package names are unique dict keys, so sorting items compares names only.
    conflicts = {
        package: dict(versions)
        for package, versions in sorted(version_dist.items())
        if len(versions) > 1
    }

    with open('plot/version_conflicts.txt', 'w', encoding='utf-8') as f:
        f.write("Version Conflicts Analysis\n")
        f.write("=======================\n\n")
        for package, versions in conflicts.items():
            f.write(f"\n{package}:\n")
            ranked = sorted(versions.items(), key=lambda item: item[1], reverse=True)
            for version, count in ranked:
                f.write(f"  {version}: {count} occurrences\n")

    return conflicts

def main():
    """Run every text-classification analysis step and list the outputs.

    Creates ``plot/``, processes the eight
    ``successful_combinations_text-classification <n>.json`` files (n = 1..8)
    through each plotting/reporting step, then prints a summary.
    """
    ensure_plot_dir()

    iteration_files = [
        f'successful_combinations_text-classification {n}.json'
        for n in range(1, 9)
    ]

    steps = (
        create_iteration_progress_plot,
        create_package_heatmap,
        create_version_distribution_plot,
        analyze_version_conflicts,
    )
    for step in steps:
        step(iteration_files)

    summary = [
        "Analysis complete. Generated files in plot/ directory:",
        "1. iteration_progress.png - Shows progress across iterations",
        "2. package_heatmap.png - Shows package co-occurrence",
        "3. version_distribution.png - Shows version distribution for major packages",
        "4. version_conflicts.txt - Detailed analysis of version conflicts",
    ]
    print("\n".join(summary))

# Entry point. Inside the notebook kernel __name__ is "__main__", so main()
# runs as soon as this cell executes (the recorded output above shows it did).
if __name__ == "__main__":
    main() 

Defaulting to user installation because normal site-packages is not writeable
Collecting seaborn
  Downloading seaborn-0.13.2-py3-none-any.whl (294 kB)
     |████████████████████████████████| 294 kB 5.3 MB/s            
Installing collected packages: seaborn
Successfully installed seaborn-0.13.2
Analysis complete. Generated files in plot/ directory:
1. iteration_progress.png - Shows progress across iterations
2. package_heatmap.png - Shows package co-occurrence
3. version_distribution.png - Shows version distribution for major packages
4. version_conflicts.txt - Detailed analysis of version conflicts


In [1]:
import json
import networkx as nx
import matplotlib.pyplot as plt
from collections import defaultdict
import numpy as np
import seaborn as sns

def load_all_combinations(iteration_files):
    """Load every iteration file and merge the successful combinations per model.

    Args:
        iteration_files: Iterable of JSON paths. Each file maps model ids to
            a list of ``{package: version}`` dicts; falsy entries (no
            successful combination) are skipped.

    Returns:
        dict of model id -> list of combinations, concatenated across files
        in the order given. Models with no success in any file are omitted.
    """
    all_combinations = {}
    for file in iteration_files:
        # Decode explicitly as UTF-8 rather than the platform default.
        with open(file, 'r', encoding='utf-8') as f:
            data = json.load(f)
        for model_id, model_data in data.items():
            if model_data:
                all_combinations.setdefault(model_id, []).extend(model_data)
    return all_combinations

def create_compatibility_graph(combinations):
    """Build an undirected graph of package versions that worked together.

    Each node is a ``"package==version"`` string carrying ``package`` and
    ``version`` attributes. An edge joins two versions of different packages
    that appeared in the same successful combination. Its ``weight`` is the
    number of combinations in which that pair appeared together. Previously
    every ``add_edge`` reset the weight to 1, so repeated pairings were lost.

    Args:
        combinations: Mapping of model id -> list of ``{package: version}``
            dicts, as returned by ``load_all_combinations``.

    Returns:
        networkx.Graph as described above.
    """
    import itertools

    G = nx.Graph()
    for model_data in combinations.values():
        for combo in model_data:
            # Order by package name so each unordered pair is produced once
            # (the same effect as the previous pkg1 < pkg2 filter).
            nodes = []
            for pkg, ver in sorted(combo.items(), key=lambda item: item[0]):
                node = f"{pkg}=={ver}"
                G.add_node(node, package=pkg, version=ver)
                nodes.append(node)

            for node1, node2 in itertools.combinations(nodes, 2):
                if G.has_edge(node1, node2):
                    G[node1][node2]['weight'] += 1
                else:
                    G.add_edge(node1, node2, weight=1)

    return G

def plot_compatibility_network(G, output_file, seed=42):
    """Draw the compatibility graph with a force-directed layout and save it.

    Args:
        G: networkx graph of ``"package==version"`` nodes.
        output_file: Path of the PNG to write (saved at 300 dpi).
        seed: Random seed for ``nx.spring_layout``. The layout starts from
            random positions, so without a fixed seed every run saved a
            different picture. Pass ``None`` to get the old unseeded
            behaviour.
    """
    plt.figure(figsize=(20, 20))

    # Force-directed layout; k sets the preferred node spacing.
    pos = nx.spring_layout(G, k=1, iterations=50, seed=seed)

    nx.draw_networkx_nodes(G, pos,
                           node_color='lightblue',
                           node_size=2000,
                           alpha=0.7)

    nx.draw_networkx_edges(G, pos,
                           edge_color='gray',
                           width=1,
                           alpha=0.5)

    nx.draw_networkx_labels(G, pos,
                            font_size=8,
                            font_weight='bold')

    plt.title("Package Version Compatibility Network", pad=20)
    plt.axis('off')
    plt.tight_layout()
    plt.savefig(output_file, dpi=300, bbox_inches='tight')
    plt.close()

def create_compatibility_matrix(G, output_file):
    """Save a heatmap of edge weights between every pair of package versions.

    Rows and columns are the graph's nodes in sorted order, so all versions
    of one package sit next to each other. Cell (i, j) holds the ``weight``
    of the edge between the two nodes. It is 0 when the two never appeared
    together, which includes the diagonal.

    Args:
        G: networkx graph whose edges carry a numeric ``weight`` attribute.
        output_file: Path of the PNG to write (saved at 300 dpi).

    Prints a message and writes nothing for an empty graph, since an empty
    matrix cannot be drawn as a heatmap.
    """
    if G.number_of_nodes() == 0:
        print("Compatibility graph is empty; skipping compatibility matrix.")
        return

    # Node names are "package==version", so a plain sort groups by package.
    nodes = sorted(G.nodes())
    # Dense adjacency matrix of edge weights; non-edges become 0.
    matrix = nx.to_numpy_array(G, nodelist=nodes, weight='weight')

    plt.figure(figsize=(15, 15))
    sns.heatmap(matrix,
                xticklabels=nodes,
                yticklabels=nodes,
                cmap='YlOrRd',
                annot=True,
                fmt='g')
    plt.title("Package Version Compatibility Matrix")
    plt.xticks(rotation=45, ha='right')
    plt.yticks(rotation=0)
    plt.tight_layout()
    plt.savefig(output_file, dpi=300, bbox_inches='tight')
    plt.close()

def analyze_compatibility_stats(G):
    """Print, per package, its versions and what each version co-occurred with.

    Nodes are grouped by their ``package`` attribute, with packages in
    first-seen order. Within a package, the ``"pkg==ver"`` node names are
    printed in sorted (lexicographic) order. Each is followed by a
    comma-separated list of its neighbours in ``G``.
    """
    by_package = defaultdict(list)
    for node_name in G.nodes():
        by_package[G.nodes[node_name]['package']].append(node_name)

    print("\nCompatibility Analysis:")
    print("=====================")
    for pkg_name, node_names in by_package.items():
        print(f"\n{pkg_name}:")
        print(f"  Number of versions: {len(node_names)}")
        print("  Versions:")
        for node_name in sorted(node_names):
            print(f"    - {node_name}")
            neighbours = ', '.join(G.neighbors(node_name))
            print(f"      Compatible with: {neighbours}")

def main():
    """Build the package-version compatibility graph and render it.

    Reads ``successful_combinations <i>.json`` for i = 1..7 from the working
    directory. Writes two PNGs into ``plot/`` and prints per-package
    statistics.
    """
    import os

    # This cell is self-contained (it re-imports its own dependencies), so it
    # must not rely on an earlier cell having created the output directory;
    # matplotlib's savefig does not create missing parent directories.
    os.makedirs('plot', exist_ok=True)

    # NOTE(review): the first cell reads
    # 'successful_combinations_text-classification <n>.json' for n = 1..8,
    # while this reads a different name pattern and only 1..7. Confirm both
    # the pattern and the upper bound are intended.
    iteration_files = [
        f'successful_combinations {i}.json'
        for i in range(1, 8)
    ]

    combinations = load_all_combinations(iteration_files)
    G = create_compatibility_graph(combinations)

    plot_compatibility_network(G, 'plot/compatibility_network.png')
    create_compatibility_matrix(G, 'plot/compatibility_matrix.png')

    analyze_compatibility_stats(G)

    print("\nVisualization complete. Generated files:")
    print("1. plot/compatibility_network.png - Network visualization of package compatibility")
    print("2. plot/compatibility_matrix.png - Heatmap of package compatibility")

# Entry point. Inside the notebook kernel __name__ is "__main__", so main()
# runs as soon as this cell executes.
# NOTE(review): the recorded run of this cell stopped with
# "ModuleNotFoundError: No module named 'seaborn'", although the first cell's
# `!pip install` reported "Defaulting to user installation". Presumably the
# package went to a site directory this kernel does not see; restarting the
# kernel or installing with `%pip install seaborn` likely fixes it — confirm.
if __name__ == "__main__":
    main() 

ModuleNotFoundError: No module named 'seaborn'