In [1]:
from functools import partial
from importlib import reload
import unittest

from silq.pulses import PulseSequence, DCPulse, TriggerPulse, SinePulse
from silq.instrument_interfaces import get_instrument_interface

In [2]:
%%javascript

// Bind the command-mode key 'r' to "restart the kernel and run all cells".
// NOTE(review): in the classic Notebook 'r' is normally bound to "change cell to raw";
// this binding replaces it — confirm that is intended.
Jupyter.keyboard_manager.command_shortcuts.add_shortcut('r', {
    // Description listed for this shortcut in the keyboard-shortcut help; it must
    // describe what the handler actually does (restart + run all, not "run cell").
    help : 'restart kernel and run all cells',
    help_index : 'zz',
    handler : function (event) {
        // `Jupyter` is the same global used on the first line (`IPython` is its legacy alias).
        Jupyter.notebook.restart_run_all();
        return false;
    }}
);

<IPython.core.display.Javascript object>

In [3]:
# General functions
def print_function(*args, **kwargs):
    """Echo the arguments this function was called with.

    Prints one line of the form ``args=<tuple>, kwargs=<dict>`` to stdout and
    returns None. Handy as a stand-in callback when mocking instruments.
    """
    message = 'args=%s, kwargs=%s' % (args, kwargs)
    print(message)

class AddChannelFunctions:
    """Class decorator that attaches one stub attribute per (channel, function) pair.

    For every channel ``ch`` in ``channels`` and every name ``fn`` in ``functions``
    the decorated class gets an attribute ``{ch}_{fn}`` (e.g. ``ch1_trigger_source``).
    Calling that attribute invokes ``handler`` with the call's own arguments plus
    the keyword arguments ``ch=<channel>`` and ``function=<fn>``.

    Args:
        channels: Iterable of channel identifiers; each one's string form is the
            attribute-name prefix.
        functions: Iterable of function-name suffixes. It is materialised once per
            decoration, so generators work for every channel.
        handler: Optional callable invoked by every generated stub. When omitted,
            the stub prints ``args=..., kwargs=...`` to stdout.
    """

    def __init__(self, channels, functions, handler=None):
        self.channels = channels
        self.functions = functions
        self.handler = handler

    @staticmethod
    def _print_call(*args, **kwargs):
        """Default handler: print the positional and keyword arguments received."""
        print('args={args}, kwargs={kwargs}'.format(args=args, kwargs=kwargs))

    def __call__(self, cls):
        """Attach the generated stubs to ``cls`` and return ``cls`` (decorator protocol)."""
        handler = self.handler if self.handler is not None else self._print_call
        # Materialise once: iterating a one-shot iterable per channel would leave
        # every channel after the first without stubs.
        function_names = list(self.functions)
        for channel in self.channels:
            for function in function_names:
                # functools.partial objects are not descriptors, so the stub is not
                # bound to instances: `instance.ch1_fn(x)` calls handler(x, ch=..., function=...)
                # without inserting `self`.
                stub = partial(handler, ch=channel, function=function)
                # setattr instead of exec(): no code generation from strings.
                setattr(cls, '{ch}_{fn}'.format(ch=str(channel), fn=function), stub)
        return cls

In [None]:
# Test PulseSequence
class TestPulseSequence(unittest.TestCase):
    """Unit tests for adding, clearing and ordering pulses in a PulseSequence."""

    def setUp(self):
        # A fresh, empty sequence for every test keeps the tests independent.
        self.pulse_sequence = PulseSequence()

    def test_add_remove_pulse(self):
        """An added pulse is a member of the sequence; clear() leaves no pulses."""
        sequence = self.pulse_sequence
        dc_pulse = DCPulse(name='dc', amplitude=1.5, duration=10, t_start=0)
        sequence.add(dc_pulse)
        self.assertIn(dc_pulse, sequence)

        # Remove pulses
        sequence.clear()
        self.assertEqual(len(sequence.pulses), 0)

    def test_sort(self):
        """Expects indexing to follow t_start order rather than insertion order."""
        later = DCPulse(name='dc1', amplitude=1.5, duration=10, t_start=1)
        earlier = DCPulse(name='dc2', amplitude=1.5, duration=10, t_start=0)
        # Insert the later pulse first so a passing test actually requires reordering.
        for pulse in (later, earlier):
            self.pulse_sequence.add(pulse)
        self.assertEqual(earlier, self.pulse_sequence[0])


In [None]:
# Test ArbStudio

# Channel names ('ch1'..'ch4') and per-channel method suffixes that the mock gets stubs for.
channels = ['ch{}'.format(index) for index in range(1, 5)]
functions = ['trigger_source', 'trigger_mode', 'add_waveform', 'sequence']


@AddChannelFunctions(channels, functions)
class MockArbStudio:
    """Hardware-free stand-in for an ArbStudio instrument.

    Its only behaviour comes from the ``<channel>_<function>`` stub attributes
    attached by the ``AddChannelFunctions`` decorator; it stores no state itself.
    """

    def __init__(self):
        # Nothing to initialise: the mock carries no state.
        pass
    
class TestArbStudio(unittest.TestCase):
    """Tests for the ArbStudio instrument interface, driven by MockArbStudio."""

    def setUp(self):
        # Fresh mock instrument and interface per test so no state leaks between tests.
        self.pulse_sequence = PulseSequence()
        self.arbstudio = MockArbStudio()
        self.arbstudio_interface = get_instrument_interface(self.arbstudio)

    def test_pulse_implementation(self):
        """Expects no implementation for a sine pulse, one for a DC pulse at
        amplitude 1, and none once that DC pulse's amplitude is raised to 3."""
        implementation_of = self.arbstudio_interface.get_pulse_implementation

        sine = SinePulse(t_start=0, duration=10, frequency=1e6, amplitude=1)
        self.assertIsNone(implementation_of(sine))

        dc_pulse = DCPulse(t_start=0, duration=10, amplitude=1)
        self.assertIsNotNone(implementation_of(dc_pulse))

        # Same pulse object, amplitude raised to 3: expected to be rejected.
        dc_pulse.amplitude = 3
        self.assertIsNone(implementation_of(dc_pulse))

    def test_ELR_programming(self):
        """Expects the interface's pulse_sequence.pulses to equal the list of
        empty/load/read DC pulses added to it."""
        segments = (('empty', 0, 1.5), ('load', 10, -1.5), ('read', 20, 0))
        pulses = [DCPulse(name=label, t_start=start, duration=10, amplitude=level)
                  for label, start, level in segments]
        for each in pulses:
            self.arbstudio_interface.pulse_sequence.add(each)
        self.assertEqual(pulses, self.arbstudio_interface.pulse_sequence.pulses)

    

In [21]:
# Run tests: collect every test method of each TestCase class and run them together.
# loadTestsFromTestCase takes the TestCase *class*; loadTestsFromModule expects a
# module object, not a TestCase instance.
loader = unittest.TestLoader()
suite_pulse_sequence = loader.loadTestsFromTestCase(TestPulseSequence)
suite_arbstudio = loader.loadTestsFromTestCase(TestArbStudio)
suite = unittest.TestSuite([suite_pulse_sequence, suite_arbstudio])
unittest.TextTestRunner().run(suite)

....
----------------------------------------------------------------------
Ran 4 tests in 0.003s

OK


<unittest.runner.TextTestResult run=4 errors=0 failures=0>