In [4]:
import plotly.graph_objects as go
import ipywidgets as widgets
from IPython.display import display, clear_output
import pandas as pd
import json
from datetime import datetime
import os

class PersistentSankey:
    """Interactive notebook tool for editing, saving, loading and plotting Sankey flows.

    ``main_container`` is the widget tree to display. It holds:

    * file controls: filename box, format dropdown, Save/Load buttons;
    * one editable row per flow: source, target, value, Remove button;
    * Add/Create buttons;
    * an ``Output`` widget that receives the diagram and status messages.
    """

    def __init__(self):
        # Data persistence widgets
        self.filename = widgets.Text(
            description='Filename(without extension):',
            placeholder='sankey_data',
            layout=widgets.Layout(width='300px')
        )

        self.file_format = widgets.Dropdown(
            options=['json', 'csv'],
            value='json',
            description='Format:',
            layout=widgets.Layout(width='200px')
        )

        self.save_btn = widgets.Button(
            description='Save Data',
            button_style='primary',
            icon='save'
        )

        self.load_btn = widgets.Button(
            description='Load Data',
            button_style='warning',
            icon='upload'
        )

        # Horizontal line
        self.separator = widgets.HTML(
            value='<hr style="height:2px;border-width:0;color:gray;background-color:gray;margin:20px 0;">'
        )

        # Input container; its children always mirror self.flow_rows
        self.input_container = widgets.VBox([])
        self.flow_rows = []

        # Action buttons
        self.add_btn = widgets.Button(
            description='Add New Flow',
            button_style='info',
            icon='plus',
            layout=widgets.Layout(width='150px')
        )

        self.create_btn = widgets.Button(
            description='Create Diagram',
            button_style='success',
            icon='chart-line',
            layout=widgets.Layout(width='150px')
        )

        # Output area for diagrams and status messages
        self.output = widgets.Output()

        # Bind events
        self.add_btn.on_click(self.add_flow_row)
        self.create_btn.on_click(self.create_diagram)
        self.save_btn.on_click(self.save_data)
        self.load_btn.on_click(self.load_data)

        # Main layout
        self.main_container = widgets.VBox([

            # Separator
            self.separator,

            # Data persistence controls
            widgets.HBox([
                self.filename,
                self.file_format,
                self.save_btn,
                self.load_btn
            ], layout=widgets.Layout(margin='10px 0px')),

            # Separator
            self.separator,

            # Flow inputs
            self.input_container,

            # Separator
            self.separator,

            # Action buttons in single row
            widgets.HBox([
                self.add_btn,
                widgets.HTML(value='<div style="width:20px"></div>'),  # Spacer
                self.create_btn
            ], layout=widgets.Layout(
                display='flex',
                justify_content='center',
                margin='20px 0px'
            )),

            # Output area
            self.output
        ])

        # Start with one empty row so the form is never blank
        self.add_flow_row(None)

    def _report(self, message):
        """Clear the output area and print *message* into it."""
        with self.output:
            clear_output()
            print(message)

    def create_flow_row(self):
        """Create one flow input row: source, target, value and a Remove button.

        Returns:
            tuple: ``(row, source, target, value)``, where ``row`` is the
            HBox and the other three are its input widgets. Callers use the
            widgets to pre-fill values.
        """
        source = widgets.Text(
            placeholder='Enter source',
            description='Source:',
            layout=widgets.Layout(width='200px')
        )

        target = widgets.Text(
            placeholder='Enter target',
            description='Target:',
            layout=widgets.Layout(width='200px')
        )

        # FloatText defines no ``placeholder`` trait. Passing one is ignored
        # apart from a traitlets DeprecationWarning, so it is omitted here.
        value = widgets.FloatText(
            description='Value:',
            layout=widgets.Layout(width='200px')
        )

        remove_btn = widgets.Button(
            description='Remove',
            button_style='danger',
            icon='trash',
            layout=widgets.Layout(width='100px')
        )

        row = widgets.HBox([source, target, value, remove_btn],
                           layout=widgets.Layout(margin='5px 0px'))

        def remove_row(b):
            # Never remove the last remaining row.
            if len(self.flow_rows) > 1:
                self.flow_rows.remove(row)
                self.input_container.children = tuple(self.flow_rows)

        remove_btn.on_click(remove_row)
        return row, source, target, value

    def add_flow_row(self, b):
        """Append a new, empty flow row (button callback; *b* is unused)."""
        row, _, _, _ = self.create_flow_row()
        self.flow_rows.append(row)
        self.input_container.children = tuple(self.flow_rows)

    def get_current_data(self):
        """Return the filled-in flows as ``[{'source', 'target', 'value'}, ...]``.

        Rows with an empty source or target are skipped. A FloatText value is
        never None, so a value of 0.0 still counts as filled in.
        """
        data = []
        for row in self.flow_rows:
            source = row.children[0].value
            target = row.children[1].value
            value = row.children[2].value
            if source and target and value is not None:
                data.append({
                    'source': source,
                    'target': target,
                    'value': value
                })
        return data

    def save_data(self, b):
        """Save the current flows to ``<filename>.<format>`` as JSON or CSV.

        An empty filename box falls back to a timestamped ``sankey_data_*``
        name. Success or failure is reported in the output area.
        """
        data = self.get_current_data()
        if not data:
            self._report("Error: No data to save")
            return

        filename = self.filename.value.strip() or f'sankey_data_{datetime.now().strftime("%Y%m%d_%H%M%S")}'
        fmt = self.file_format.value  # one of the dropdown options: 'json' / 'csv'
        path = f'{filename}.{fmt}'

        try:
            if fmt == 'json':
                # Use explicit UTF-8 so non-ASCII node names round-trip on any platform.
                with open(path, 'w', encoding='utf-8') as f:
                    json.dump(data, f, indent=4, ensure_ascii=False)
            else:  # csv
                pd.DataFrame(data).to_csv(path, index=False)
        except Exception as e:  # UI callback boundary: report instead of raising
            self._report(f"Error saving data: {str(e)}")
            return

        self._report(f"Data saved successfully to {path}")

    @staticmethod
    def _read_records(path, fmt):
        """Read flow records from *path*: a JSON list of dicts, or CSV rows as dicts."""
        if fmt == 'json':
            with open(path, 'r', encoding='utf-8') as f:
                return json.load(f)
        # Keep node names as literal strings. Without this, pandas turns names
        # like "NA" or "None" into NaN and numeric names into ints.
        df = pd.read_csv(path, dtype={'source': str, 'target': str}, keep_default_na=False)
        return df.to_dict('records')

    @staticmethod
    def _as_text(item):
        """Convert a loaded node name to the ``str`` the Text widget requires."""
        return '' if item is None else str(item)

    def load_data(self, b):
        """Replace all flow rows with the contents of ``<filename>.<format>``.

        The new rows are built before anything is replaced. If reading or
        parsing fails, the existing rows stay untouched and the error is
        reported in the output area.
        """
        filename = self.filename.value.strip()
        fmt = self.file_format.value

        if not filename:
            self._report("Error: Please provide a filename")
            return

        path = f'{filename}.{fmt}'
        try:
            new_rows = []
            for item in self._read_records(path, fmt):
                row, source, target, value = self.create_flow_row()
                source.value = self._as_text(item['source'])
                target.value = self._as_text(item['target'])
                # The Float trait rejects numpy integer types (e.g. from CSV).
                value.value = float(item['value'])
                new_rows.append(row)
        except Exception as e:  # UI callback boundary: report instead of raising
            self._report(f"Error loading data: {str(e)}")
            return

        # Keep the at-least-one-row invariant even for an empty file.
        if not new_rows:
            new_rows.append(self.create_flow_row()[0])

        self.flow_rows = new_rows
        self.input_container.children = tuple(self.flow_rows)
        self._report(f"Data loaded successfully from {path}")

    @staticmethod
    def _index_nodes(source, target):
        """Return ``(labels, source_idx, target_idx)`` for the given node names.

        Labels are listed in first-seen order (sources first, then targets).
        Indices therefore stay the same between runs, unlike ``set`` order,
        which varies with string-hash randomization.
        """
        index = {name: i for i, name in enumerate(dict.fromkeys([*source, *target]))}
        return list(index), [index[s] for s in source], [index[t] for t in target]

    @staticmethod
    def _link_colors(value):
        """Return one rgba colour per link, with opacity scaled by the link's value.

        The largest value gets alpha 1.0, zero gets 0.3, and anything in
        between is linear. Negative values are clamped to 0.3. If no value is
        positive (e.g. all zeros), every link is fully opaque, which avoids
        dividing by zero.
        """
        max_value = max(value, default=0)
        colors = []
        for v in value:
            ratio = v / max_value if max_value > 0 else 1.0
            alpha = max(0.3, min(1.0, 0.3 + 0.7 * ratio))
            colors.append(f'rgba(0,128,255,{alpha})')
        return colors

    def create_sankey(self, source, target, value):
        """Build a Plotly Sankey figure from parallel lists.

        Args:
            source: source node name for each link.
            target: target node name for each link.
            value: weight of each link.

        Returns:
            plotly.graph_objects.Figure: the Sankey figure (not yet shown).
        """
        labels, source_idx, target_idx = self._index_nodes(source, target)

        fig = go.Figure(data=[go.Sankey(
            node=dict(
                pad=15,
                thickness=20,
                line=dict(color="black", width=0.5),
                label=labels,
                color="lightblue"
            ),
            link=dict(
                source=source_idx,
                target=target_idx,
                value=value,
                color=self._link_colors(value)
            )
        )])

        fig.update_layout(
            title_text="Sankey Diagram",
            font_size=12,
            height=600
        )

        return fig

    def create_diagram(self, b):
        """Render the current flows as a Sankey diagram in the output area."""
        with self.output:
            clear_output()

            data = self.get_current_data()
            if not data:
                print("Error: Please add at least one complete flow (Source, Target, and Value)")
                return

            sources = [d['source'] for d in data]
            targets = [d['target'] for d in data]
            values = [d['value'] for d in data]

            fig = self.create_sankey(sources, targets, values)
            fig.show()

# Build the interactive Sankey editor and render its widget tree in this cell's output.
# `sankey_tool` is kept as a module-level name, so the same instance is reachable later.
sankey_tool = PersistentSankey()
display(sankey_tool.main_container)

VBox(children=(HTML(value='<hr style="height:2px;border-width:0;color:gray;background-color:gray;margin:20px 0…