In [6]:
import jax
import jax.extend
import qiskit
import flax
import tensorflow as tf
import qiskit_aer

# Report the versions of the key libraries and the JAX backend platform so the
# recorded outputs identify the environment this notebook was run in.
print(f"JAX Version: {jax.__version__} (Platform: {jax.extend.backend.get_backend().platform})")
print(f"Flax Version: {flax.__version__}")
print(f"Qiskit Version: {qiskit.__version__}")
print(f"Qiskit Aer Version: {qiskit_aer.__version__}")
print(f"TensorFlow Version: {tf.__version__}")

# Check JAX devices. Catch Exception rather than using a bare `except:` so that
# KeyboardInterrupt/SystemExit still propagate, and surface the cause of the
# failure instead of silently discarding it.
try:
    print(f"JAX Devices: {jax.devices()}")
except Exception as exc:
    print(f"JAX Device detection failed: {exc!r}")

JAX Version: 0.8.1 (Platform: METAL)
Flax Version: 0.12.1
Qiskit Version: 2.1.2
Qiskit Aer Version: 0.17.2
TensorFlow Version: 2.18.1
JAX Devices: [METAL(id=0)]


In [9]:
from graphviz import Digraph

def create_hybrid_architecture_diagram(filename='hybrid_architecture', fmt='png', view=True):
    """Render the hybrid quantum-diffusion architecture diagram with Graphviz.

    The diagram has two clusters:

    - the CPU host: the TensorFlow data loader, plus a nested "quantum core"
      with the cached ISA circuit and the Qiskit Aer estimator;
    - the Apple Metal GPU: the JAX U-Net encoder/decoder and the Optax
      loss/update step.

    Edges connect the clusters along the hybrid forward/backward data flow.

    Args:
        filename: Output file stem passed to ``Digraph.render``. The format
            extension is appended to the returned path
            (e.g. ``hybrid_architecture.png``).
        fmt: Graphviz output format, e.g. ``'png'`` or ``'svg'``.
        view: Whether to open the rendered file in the system viewer after
            rendering. Pass ``False`` on headless kernels.

    Side effects:
        Writes the rendered file (the intermediate DOT source is removed via
        ``cleanup=True``) and prints the resulting path.
    """
    dot = Digraph(comment='Hybrid Quantum Diffusion Architecture')
    # NOTE(review): Graphviz warns that splines='ortho' does not handle edge
    # `label`s in dot (it suggests `xlabel` instead). Check that the rendered
    # edge labels sit next to their edges.
    dot.attr(rankdir='LR', splines='ortho', pad='0.5', nodesep='0.8', ranksep='1.0')
    dot.attr('node', shape='box', style='filled,rounded', fontname='Helvetica', fontsize='12')

    # Host (CPU) cluster: data loading plus the quantum simulation core.
    with dot.subgraph(name='cluster_cpu') as cpu:
        cpu.attr(label='Host (CPU: M4 Pro Performance Cores)\n[Constraint: Single Threaded Execution]',
                 style='filled', color='#FFF3E0', fontcolor='#E65100', fontsize='14')
        cpu.node_attr.update(fillcolor='white', color='#FF9800', penwidth='2')

        cpu.node('DataLoader', 'Data Loader\n(TensorFlow)\n\nBatch Size: 32\nNo Prefetching')

        # Nested cluster for the quantum layer. Graphviz only draws subgraphs
        # whose name starts with "cluster" as enclosing boxes.
        with cpu.subgraph(name='cluster_quantum') as q:
            q.attr(label='Quantum Core (quantum_layer.py)', style='dashed', color='#FFB74D', fontcolor='#EF6C00')

            q.node('CircuitCache', 'Cached ISA Circuit\n(Pre-transpiled)\n[Template]', shape='note', style='filled', fillcolor='#FFE0B2')

            q.node('AerSimulator', 'Qiskit Aer Estimator\n\n[Protected by Lock]\n[RAYON_THREADS=1]')

    # Device (GPU) cluster: the JAX U-Net halves and the optimizer step.
    with dot.subgraph(name='cluster_gpu') as gpu:
        gpu.attr(label='Device (Apple Metal GPU)\n[Accelerated by JAX/XLA]',
                 style='filled', color='#E1F5FE', fontcolor='#0277BD', fontsize='14')
        gpu.node_attr.update(fillcolor='white', color='#039BE5', penwidth='2')

        gpu.node('Encoder', 'U-Net Encoder\n(DownBlock, ResNet)\n[JAX]')
        gpu.node('Decoder', 'U-Net Decoder\n(UpBlock, Attention)\n[JAX]')
        gpu.node('Optimization', 'Loss Calculation\n& Param Update\n(Optax)')

    # Input pipeline: host data is handed to the GPU encoder.
    dot.edge('DataLoader', 'Encoder', label=' Transfer to GPU', style='dashed', color='#757575')

    # GPU -> CPU bridge into the quantum simulator (highlighted in red).
    dot.edge('Encoder', 'AerSimulator',
             label=' <Hybrid Bridge>\n1. pure_callback\n2. jax.device_put(cpu)\n3. Padding & Mapping',
             color='#D50000', penwidth='2.5', fontcolor='#D50000')

    dot.edge('CircuitCache', 'AerSimulator', label=' Load Template', style='dotted', dir='both', arrowtail='odot')

    # CPU -> GPU return path from the simulator (highlighted in blue).
    dot.edge('AerSimulator', 'Decoder',
             label=' 4. Result (Float32)\n5. Return to GPU',
             color='#2962FF', penwidth='2.5', fontcolor='#2962FF')

    # Training loop: forward output into the loss, gradients back to the encoder.
    dot.edge('Decoder', 'Optimization', label=' Forward Output')
    dot.edge('Optimization', 'Encoder', label=' Backward (Gradients)', style='dashed', color='#00C853', fontcolor='#00C853')

    output_path = dot.render(filename, format=fmt, cleanup=True, view=view)
    print(f"✅ Diagram generated: {output_path}")

# Render the hybrid CPU/GPU architecture diagram. The call writes
# hybrid_architecture.png and prints its path.
create_hybrid_architecture_diagram()



✅ Diagram generated: hybrid_architecture.png


In [None]:
from graphviz import Digraph

def draw_method2_final_step_arrow(filename='QConvUnet', fmt='png'):
    """Render the QConv U-Net architecture diagram with Graphviz.

    Each encoder level, decoder level and the bottleneck is drawn as a
    one-row HTML-like table. Each cell is one of that level's blocks (Conv,
    ResNet, QResNet, Attention), and tables containing quantum QResNet blocks
    get a red border.

    Edges form the U shape:

    - encoder levels are chained downward into the bottleneck;
    - the decoder chain runs back up from it;
    - a skip connection joins each encoder level to the decoder level on the
      same rank.

    Each node's xlabel is the level's feature-map size (presumably
    H x W x channels).

    Args:
        filename: Output file stem passed to ``Digraph.render``.
        fmt: Graphviz output format, e.g. ``'png'`` or ``'svg'``.

    Side effects:
        Writes the rendered file (the intermediate DOT source is removed via
        ``cleanup=True``) and prints the resulting path.
    """
    import html  # only needed to escape cell captions for the HTML-like labels

    dot = Digraph(comment='QConvU-Net')

    # NOTE(review): splines='ortho' may not honor the port/compass-point
    # endpoints ('Node:port:side') used by the edges below. Confirm the arrows
    # attach to the intended cells in the rendered image.
    dot.attr(rankdir='TB', splines='ortho', nodesep='1.0', ranksep='0.8')
    # shape='plain' draws no outline or padding of its own, so the HTML table
    # itself forms the node.
    dot.attr('node', shape='plain', fontname='Helvetica')

    # Color definitions
    COL_CONV = "#DCEDC8"       # Green
    COL_RESNET = "#E3F2FD"     # Sky Blue
    COL_ATTN = "#FFE0B2"       # Apricot
    COL_QUAN = "#E1BEE7"       # Purple/Pink
    BORDER_QUAN = "#D32F2F"    # Red Border
    BORDER_DEF = "#90CAF9"     # Default Blue Border

    # Cell background per block type; unknown types fall back to white.
    cell_colors = {'Conv': COL_CONV, 'Res': COL_RESNET, 'Attn': COL_ATTN, 'Quan': COL_QUAN}

    def make_row(cells, border=BORDER_DEF):
        """Build a Graphviz HTML-like label: a one-row table, one cell per block.

        Args:
            cells: Iterable of ``(c_type, c_text, c_port)`` tuples.
                ``c_type`` selects the cell background color. ``c_text`` is
                the caption, where a newline marks a line break. ``c_port``
                names the cell so edges can attach to it as ``'Node:port'``.
            border: Table border color.

        Returns:
            The label string, wrapped in ``<...>`` so Graphviz parses it as HTML.
        """
        rows_html = ""
        for c_type, c_text, c_port in cells:
            bg = cell_colors.get(c_type, "#FFFFFF")
            # HTML-like labels treat a raw newline as plain whitespace, so line
            # breaks must be written as <BR/>. Escape first so '<', '>' and '&'
            # in a caption cannot break the markup.
            cell_text = html.escape(c_text).replace('\n', '<BR/>')
            rows_html += f'<TD BGCOLOR="{bg}" PORT="{c_port}">{cell_text}</TD>'

        return f'''<
        <TABLE BORDER="0" CELLBORDER="1" CELLSPACING="2" CELLPADDING="10" BGCOLOR="white" COLOR="{border}">
            <TR>{rows_html}</TR>
        </TABLE>>'''

    # Level 1 (Top)
    # Enc1: [Conv, Res, Attn, Res_Last] -> Res_Last
    dot.node('Enc1', make_row([
        ('Conv', 'Conv', 'c_conv'),
        ('Res', 'ResNet', 'c_r1'),
        ('Attn', 'Attention\nBlock', 'c_attn'),
        ('Res', 'ResNet', 'c_out')  # Output Port
    ]), xlabel='28x28x10')

    # Dec1: [Res_First, Attn, Res, Conv] -> Res_First
    dot.node('Dec1', make_row([
        ('Res', 'ResNet', 'c_in'),  # Input Port
        ('Attn', 'Attention\nBlock', 'c_attn'),
        ('Res', 'ResNet', 'c_r2'),
        ('Conv', 'Conv', 'c_conv')
    ]), xlabel='28x28x10')

    # Level 2 (Middle)
    # Enc2: [Quan_First, Attn, Quan_Last]
    dot.node('Enc2', make_row([
        ('Quan', 'QResNet', 'c_in'),  # Input Port
        ('Attn', 'Attention\nBlock', 'c_attn'),
        ('Quan', 'QResNet', 'c_out')  # Output Port
    ], border=BORDER_QUAN), xlabel='14x14x20')

    # Dec2
    dot.node('Dec2', make_row([
        ('Res', 'ResNet', 'c_in'),    # Input Port
        ('Attn', 'Attention\nBlock', 'c_attn'),
        ('Res', 'ResNet', 'c_out')    # Output Port
    ]), xlabel='14x14x20')

    # Level 3 (Bottom)
    dot.node('Enc3', make_row([
        ('Res', 'ResNet', 'c_in'),
        ('Attn', 'Attention\nBlock', 'c_attn'),
        ('Res', 'ResNet', 'c_out')
    ]), xlabel='7x7x30')

    dot.node('Dec3', make_row([
        ('Res', 'ResNet', 'c_in'),
        ('Attn', 'Attention\nBlock', 'c_attn'),
        ('Res', 'ResNet', 'c_out')
    ]), xlabel='7x7x30')

    # Bottleneck
    dot.node('Bottleneck', make_row([
        ('Quan', 'QResNet', 'c_in'),
        ('Attn', 'Attention\nBlock', 'c_attn'),
        ('Quan', 'QResNet', 'c_out')
    ], border=BORDER_QUAN), xlabel='2x2x40')

    # Pin each encoder level beside its decoder counterpart (same rank).
    with dot.subgraph() as s:
        s.attr(rank='same'); s.node('Enc1'); s.node('Dec1')
    with dot.subgraph() as s:
        s.attr(rank='same'); s.node('Enc2'); s.node('Dec2')
    with dot.subgraph() as s:
        s.attr(rank='same'); s.node('Enc3'); s.node('Dec3')

    # [Encoder Path] Downsample
    dot.edge('Enc1:c_out:s', 'Enc2:c_in:n', color='black')
    dot.edge('Enc2:c_out:s', 'Enc3:c_in:n', color='black')
    dot.edge('Enc3:c_out:s', 'Bottleneck:c_in:n', color='black')

    # [Decoder Path] Upsample
    # Each edge is declared top-down (tail above head) so the ranking keeps
    # the decoder above the bottleneck; dir='back' then draws the arrowhead
    # pointing upward, matching the actual data flow.

    # Bottleneck -> Dec3
    dot.edge('Dec3:c_in:s', 'Bottleneck:c_out:n', dir='back', color='black')

    # Dec3 -> Dec2
    dot.edge('Dec2:c_in:s', 'Dec3:c_out:n', dir='back', color='black')

    # Dec2 -> Dec1
    dot.edge('Dec1:c_in:s', 'Dec2:c_out:n', dir='back', color='black')

    # [Skip Connections] constraint='false' keeps these horizontal edges from
    # affecting rank assignment.
    dot.edge('Enc1:c_out:e', 'Dec1:c_in:w', style='solid', constraint='false', arrowsize='0.6')
    dot.edge('Enc2:c_out:e', 'Dec2:c_in:w', style='solid', constraint='false', arrowsize='0.6')
    dot.edge('Enc3:c_out:e', 'Dec3:c_in:w', style='solid', constraint='false', arrowsize='0.6')

    # NOTE(review): these two zero-size, label-less nodes are not connected by
    # any edge in this function. They are presumably placeholders for
    # input/output arrows; confirm before removing.
    dot.node('InArrow', '', shape='none', height='0', width='0')
    dot.node('OutArrow', '', shape='none', height='0', width='0')

    output_path = dot.render(filename, format=fmt, cleanup=True)
    print(f"Diagram saved to: {output_path}")

# Best-effort render: report any failure (e.g. a missing Graphviz executable)
# as a one-line message instead of aborting the cell.
try:
    draw_method2_final_step_arrow()
except Exception as render_error:
    print(f"Error: {render_error}")

Diagram saved to: QConvUnet.png
