In [84]:
class Foo:
    def __init__(self):
        self._value = 1
    @property
    def value(self):
        print('i am here too')
        return self._value

In [85]:
foo = Foo()

In [86]:
a = foo.value

i am here too


In [89]:
a

1

In [None]:
class MetaclassFoo(type):
    def __new__(cls, name, bases, dct):
        print('youve been metad')
        init = dct.get('__init__')
        if init:
            # Wrap the class's __init__ method to call super().__init__() first
            def wrapped_init(self, *args, **kwargs):
                print('All good')
                init(self, *args, **kwargs)  # Call the subclass's __init__
            dct['__init__'] = wrapped_init
        return super().__new__(cls, name, bases, dct)

class Foo(metaclass = MetaclassFoo):
    def whatever(self):
        print('something')



youve been metad


In [76]:
class A(Foo):
    def __init__(self):
        pass

youve been metad


In [78]:
a = A()

All good


In [34]:
class Foo:
    def __new__(cls):
        print('new_created')
        return super().__new__(cls)

    def whatever(self):
        print('something')

class MetaclassFoo(type, Foo):
    def __new__(cls, name, bases, dct):
        print('new_created')
        return super().__new__(cls, name, bases, dct)

In [35]:
class A(Foo):
    def method(self):
        print('method executed')

In [36]:
a = A()

new_created


In [40]:
class B(Foo, metaclass = MetaclassFoo):
    def method(self):
        print('method executed')

new_created


In [41]:
b = B()

new_created


In [81]:
class Test:
    def test(func):
        def wrapper(self, *args, **kwargs):
            print('wrapped it')
            func(self, *args, **kwargs)
        return wrapper

    @test
    def method(self, n):
        print('oh man', n)

class SubTest(Test):
    @Test.test
    def method(self, n):
        print('Am I wrapped')

class SubSubTest(SubTest):
    @SubSubTest.test
    def method(self, n):
        print('Am I wrapped')

In [60]:
t = Test()
t.method(1)

wrapped it
oh man 1


In [82]:
subt = SubTest()
subt.method(1)

wrapped it
Am I wrapped


In [83]:
subsubt = SubSubTest()
subsubt.method(1)

wrapped it
Am I wrapped


In [1]:
import time
import asyncio
import zmq
import threading
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import colors
from matplotlib.colors import LinearSegmentedColormap
from matplotlib.axes import Axes
from matplotlib.gridspec import GridSpec

# Implementation questions - how to add live plotting with minimal effort

## Thread safety

The biggest challenge with live plotting is the racing conditions when accessing shared resources such as data_handles from the QM, or even other objects that are shared between different processes. We can solve this with the following Abstract classes:

In [2]:
class ThreadSafeAbstractMeas:
    def __init__(self):
        self._lock = threading.RLock()  # Lock for thread safety
        self._data_handles = None  # Critical resource

    @property
    def data_handles(self):
        """Thread-safe getter for data_handles."""
        with self._lock:
            return self._data_handles

    @data_handles.setter
    def data_handles(self, value):
        """Thread-safe setter for data_handles."""
        with self._lock:
            self._data_handles = value

    @staticmethod
    def thread_safe(func):
        """Decorator to make methods thread-safe."""
        def wrapper(self, *args, **kwargs):
            with self._lock:
                return func(self, *args, **kwargs)
        return wrapper

This guy has:
- a lock which prevents simultaneous access to _data_handles
- data_handles property which "regulates" access to _data_handles
- a thread_safe decorator which wraps a function we want to run in a lock such that we are sure that our function is thread safe (this is a last resort option since this can lead to long locks causing some parts of the system to wait)

However, what this guy needs additionally is the following metaclass:

In [3]:
class AutoSuperInitMeta(type):
    def __new__(cls, name, bases, dct):
        global live
        # Check if the class defines its own __init__ method
        init = dct.get('__init__')
        if init:
            # Wrap the class's __init__ method to call super().__init__() first
            def wrapped_init(self, *args, **kwargs):
                for base in bases:
                    if hasattr(base, '__init__'):
                        base.__init__(self)  # Call the parent class's __init__
                init(self, *args, **kwargs)  # Call the subclass's __init__
            dct['__init__'] = wrapped_init

        # Also, we can wrap the execute measure function to call live.start(self)
        execute_meas = dct.get('execute_meas')
        if execute_meas:
            def wrapped_execute_meas(self, *args, **kwargs):
                execute_meas(self)
                live.start(self)
            dct['execute_meas'] = wrapped_execute_meas

        return super().__new__(cls, name, bases, dct)

This metaclass wraps the init and execute_meas methods, such that when we run init of a measurement, the live plotting is stopped (maybe this is not needed) and when we run execute_meas, the  live plotting automatically starts. With this, our abstract measurement class would look like the following:

In [4]:
class ThreadSafeMeasurement(metaclass=AutoSuperInitMeta):
    def __init__(self):
        global live
        print(live)
        live.stop()


        self._lock = threading.Lock()  # Lock for thread safety
        self._data_handles = None  # Critical resource


    @property
    def data_handles(self):
        """Thread-safe getter for data_handles."""
        with self._lock:
            return self._data_handles

    @data_handles.setter
    def data_handles(self, value):
        """Thread-safe setter for data_handles."""
        with self._lock:
            self._data_handles = value

    @staticmethod
    def thread_safe(func):
        """Decorator to make methods thread-safe."""
        def wrapper(self, *args, **kwargs):
            with self._lock:
                return func(self, *args, **kwargs)
        return wrapper


# Measurement classes - Example

Let's create a class that mimics measurement class in Quantrol.

In [5]:
class Example1(ThreadSafeMeasurement):
    def __init__(self,
                 w = 2*np.pi/2,
                 k = 0.2,
                 t_max = 10,
                 n_points = 100,
                 n_params = (50, 100),
                 n_iterations = 1_000,
                 k_x = 1.5,
                 k_y = 4,
                ):


        nx = n_params[1]
        ny = n_params[0]

        self.n_iterations = n_iterations
        self.iter = 0
        self.flag = True

        self.x = np.linspace(0, t_max, n_points)
        self.ps = (1-np.exp(-k*self.x)*np.sin(w*self.x))/2

        self.y = np.zeros(n_points)

        self.xs = np.linspace(0, 5, nx)
        self.ys = np.linspace(0, 10, ny)

        x, y = np.meshgrid(self.xs, self.ys)
        self.ps_z = (1-np.exp(-k*x)*np.sin(w*x))*y/2

        self.z = np.zeros((ny, nx))

        self.rng = np.random.default_rng()

    def math_func(self):
        return np.array(self.rng.uniform(0,1, size = max(self.ps.shape)) < self.ps, dtype = float)

    def map_func(self):
        return np.array(np.reshape(self.rng.uniform(0,1, size = self.ps_z.shape), self.ps_z.shape) > self.ps_z, dtype = float)

    def _get_data_handles(self):
        while self.iter < (self.n_iterations) and self.flag:
            self.iter += 1
            self.y = self.y + self.math_func()
            self.z = self.z + self.map_func()
            yield self.y/self.iter, self.z/self.iter

    def execute_meas(self):
        self.data_handles = self._get_data_handles()
        return self.data_handles

    def analysis(self):
        y, z = next(self.data_handles)
        signal = {
            'line' : y,
            'map' : z
        }
        fit = {
            'line' : self.ps,
        }

        return (signal, fit)

    def plot(self, signal, fit, live = False, save = False):
        fig, axs = plt.subplots(1, 2, figsize = (10,5))


        y = signal['line']
        z = signal['map']

        y_fit = fit['line']


        axs[0].scatter(self.x, y)
        axs[0].plot(self.x, y_fit)

        axs[1].pcolorfast(self.xs, self.ys, z)

        plt.show()

        if save:
            print('Saved')
        else:
            pass

    def getData(self):
        signal, fit = self.analysis()

        data_pack = {
            "iter" : self.iter,
            "name" : "Example",
            "n_iterations" : self.n_iterations,
            }

        data_pack["layout"] = {
                "Plot1": {
                    'content' : {
                        'signal' : {'type' : 'ScatterPlot'},
                        'fit' : {'type' : 'LinePlot'}
                    },
                    'x_label' : "t [ms]",
                    'y_label' : r'$p_e$',
                    'loc' : [0, 0, 1, 1]
                },
                "Plot2": {
                    'content' : {'map' : {'type' : 'HeatMap'}},
                    'loc' : [0, 1, 1, 1]
                },
        }

        data_pack["data"] = {
            'Plot1' : {
                'signal' : {
                    'x' : self.x,
                    'y' : signal['line'],
                },
                'fit': {
                    'x' : self.x,
                    'y' : fit['line'],
                },
            },
            'Plot2' : {
                'map' : {
                    'x' : self.xs,
                    'y' : self.ys,
                    'z' : signal['map']
                },
            }
        }
        return data_pack

In [6]:
class Example2(ThreadSafeMeasurement):
    def __init__(self,
                 N = 10,
                 t_start = 0,
                 t_stop = 0.5,
                 n_points = 100,
                 n_params = (51, 51),
                 n_iterations = 1_000,
                 k_x = 1.5,
                 k_y = 1.5,
                ):


        self.t_start = t_start
        self.t_stop = t_stop
        self.N = N
        self.y = np.array([])

        nx = n_params[1]
        ny = n_params[0]

        self.n_iterations = n_iterations
        self.iter = 0
        self.flag = True

        self.xs = np.linspace(-10, 10, nx)
        self.ys = np.linspace(-10, 10, ny)

        x, y = np.meshgrid(self.xs, self.ys)


        x_0, y_0 = 3, 3

        self.ps_z = np.exp(-((x-x_0)/k_x)**2 - ((y-y_0)/k_y)**2) + np.exp(-((x+x_0)/k_x)**2 - ((y+y_0)/k_y)**2)
        self.ps_z = self.ps_z / np.max(self.ps_z)
        self.z = np.zeros((ny, nx))

        self.rng = np.random.default_rng()

    def clicks(self):
        return self.rng.uniform(self.t_start, self.t_stop, size = self.rng.poisson(lam = self.N))

    def map_func(self):
        return np.array(np.reshape(self.rng.uniform(0,1, size = self.ps_z.shape), self.ps_z.shape) > self.ps_z, dtype = float)

    def _get_data_handles(self):
        while self.iter < (self.n_iterations) and self.flag:
            self.iter += 1
            self.y = np.append(self.y, self.clicks() + self.t_start + (self.iter - 1)*(self.t_stop - self.t_start))
            self.z = self.z + self.map_func()
            yield self.y, self.z/self.iter

    def execute_meas(self):
        self.data_handles = self._get_data_handles()
        return self.data_handles

    def analysis(self):
        y, z = next(self.data_handles)
        #y, z = next(data_handles)

        signal = {
            'clicks' : y[np.argmin(np.abs(y - (y[-1]-1))):],
            'map' : z
        }

        return signal

    def plot(self, signal, live = False, save = False):
        fig, axs = plt.subplots(1, 2, figsize = (10,5))


        y = signal['clicks']
        z = signal['map']



        axs[0].vlines(y, ymin = 0, ymax = 1)

        axs[1].pcolorfast(self.xs, self.ys, z)

        plt.show()

        if save:
            print('Saved')
        else:
            pass

    def getData(self):
        signal = self.analysis()

        data_pack = {
            "iter" : self.iter,
            "name" : "Example",
            "n_iterations" : self.n_iterations,
            }

        data_pack["layout"] = {
            "Plot1" : {
                'content' : {
                    'clicks' : {
                        'type' : 'Counts'
                    }
                },
                'x_label' : "t [ms]",
                'loc' : [0, 0, 1, 1]
            },
            "Plot2" : {
                'content' : {
                    'map' : {
                        'type' : 'HeatMap'
                    }
                },
                'loc' : [0, 1, 1, 1]
            },
        }

        data_pack["data"] = {
            'Plot1' : {
                'clicks' : {
                    'x' : signal['clicks']
                },
            },
            'Plot2' : {
                'map' : {
                    'x' : self.xs,
                    'y' : self.ys,
                    'z' : signal['map']
                },
            }
        }

        return data_pack

# Defining live plotting class

This class is responsible for sending the live plot data to the plotting widget. It creates a thread for data and a thread for control (not implemented yet), and sends data/listens to the control over the localhost ports.

In [16]:
DATA_PORT = 5561
CONTROL_PORT = 5562

In [17]:
class LivePlotting:
    def __init__(self):
        self.context = zmq.Context()

        self.socket = self.context.socket(zmq.PUB)
        self.socket.setsockopt(zmq.SNDHWM, 1)
        self.socket.bind(f"tcp://*:{DATA_PORT}")

        self.control_context = zmq.Context()
        self.control_socket = self.control_context.socket(zmq.REP)
        self.control_socket.bind(f"tcp://*:{CONTROL_PORT}")

        self._running = False
        self._save = False
        self.signal = None

        self._stop_event = threading.Event()
        self._thread = None
        self._lock = threading.Lock()

    def _fetch_and_send_data(self, meas):
        while not self._stop_event.is_set():
            with self._lock:
                try:
                    data = meas.getData().copy()
                except Exception as e:
                    self.stop()
                    # Display the exception
                    print(f"Exception occurred: {e}")

                iter = data.get('iter', 0)
                self.signal = data.get('signal', None)

            self.socket.send_pyobj(data, flags=zmq.NOBLOCK)


            if iter == meas.n_iterations:
                self._stop_event.set()

            time.sleep(self.refresh_rate)  # Adjust the sleep time as needed

    def handle_control(self):
        poller = zmq.Poller()
        poller.register(self.control_socket, zmq.POLLIN)

        while not self._stop_event.is_set():
            socks = dict(poller.poll(100))  # Poll sockets with a 100-ms timeout
            if self.control_socket in socks and socks[self.control_socket] == zmq.POLLIN:
                try:
                    message = self.control_socket.recv_string(flags=zmq.NOBLOCK)
                    if message == "SAVE":
                        self.control_socket.send_string("OK")
                        save_thread = threading.Thread(target=self.meas.plot, args= self.signal,
                                                        kwargs= {'show_fig' : False,'save' : True})
                        save_thread.start()
                except zmq.ZMQError as e:
                    if e.errno == zmq.ETERM:
                        break

    def start_control_thread(self):
        self._control_thread = threading.Thread(target=self.handle_control)
        self._control_thread.start()

    def start(self, meas, refresh_rate = 1):
        '''
        Receives as an argument a measurement and a refresh rate in seconds.
        '''
        if refresh_rate>=0.2:
            self.refresh_rate = refresh_rate
        else:
            self.refresh_rate = 0.2

        self.stop()

        self.meas = meas
        self.running = True

        self._stop_event.clear()
        self._thread = threading.Thread(target=self._fetch_and_send_data, args=(meas,))
        self._thread.start()

        self.start_control_thread()

    def stop(self):
        if self._thread is not None and self._thread.is_alive():
            self._stop_event.set()
            self._thread.join()
            self._control_thread.join()
            self.running = False

    def releaseResources(self):
        self.socket.close()
        self.context.term()

        self.control_socket.close()
        self.control_context.term()



# DEMO

First, we need to instantiate a LivePlotting class.

In [18]:
live = LivePlotting()

Let's now start the first measurement, Example1

In [28]:
meas = Example1()
data = meas.execute_meas()

<__main__.LivePlotting object at 0x000001ABB3CF6970>


Hopefully, the plotting started automatically, and now we can switch to another measurement, and the plotting should follow.

In [31]:
meas = Example2()
data = meas.execute_meas()

<__main__.LivePlotting object at 0x000001ABB3CF6970>


Important thing is that we can still work as before, we can still shift enter like monkeys and play with cells!!

In [27]:
signal, fit = meas.analysis()
print(meas.iter)
#meas.plot(signal, fit)

StopIteration: 

In order to stoo the live plotting, we can write the following command:

In [13]:
live.stop()