# Let's explore the RFX MDSplus tree

## Importing and setting up stuff

In [None]:
import MDSplus as mds
import numpy as np
import matplotlib.pyplot as plt
import sys, random
from tqdm import tqdm
print(f'Python version: {sys.version}')
print(f'MDSplus version: {mds.__version__}')
np.set_printoptions(precision=3, suppress=True)

In [None]:
#color the terminal output
def pick_random_color():
    return '\033[38;5;{}m'.format(random.randint(8, 230))
ENDC = '\033[0m'
ERR = '\033[91m'+ 'ERR: '
OK = '\033[92m' 
WARN = '\033[93m'+ 'WARN: '

In [None]:
# define the shot number and tree
SHOT = 30810
rfx = mds.Tree('rfx', SHOT, 'readonly') # open the tree read-only
from convert_to_hdf5 import SEG_FAULT_NODES

## Traversing the tree

In [None]:
# traverse the tree, use MAX_DEPTH to limit the depth of the tree to traverse
# othwerwise the script will run for about 10 minutes
MAX_DEPTH = 3# 13 # maximum depth of the tree to traverse
COLORS = [pick_random_color() for _ in range(MAX_DEPTH)]

In [None]:
usage_depth, usage_breadth = {},{}
total_nodes_depth, total_nodes_breadth = [],[]

def traverse_tree_depth_first(max_depth, node, level=0, path='', node_type='child'):
    try: 
        if level >= max_depth: return # stop if the maximum depth is reached
        if node.getFullPath() in SEG_FAULT_NODES: return # skip the nodes that cause segfault
        if node_type == 'child': node_name = node.node_name.upper()
        elif node_type == 'member': node_name = node.node_name.lower()
        else: raise
        path = path + '/' + COLORS[level] + node_name + ENDC # add the node name
        total_nodes_depth.append(node) # add the node to the list
        print(f'{path}:{node.decompile()}') 
        # get the usage/type of the node
        try: usage_depth[str(node.usage)] += 1
        except: usage_depth[str(node.usage)] = 1
        # go through the children and members of the node
        for child in node.getChildren(): # get the children of the node
            traverse_tree_depth_first(max_depth, child, level + 1, path, 'child')
        for member in node.getMembers(): # get the members of the node
            traverse_tree_depth_first(max_depth, member, level + 1, path, 'member')
    except Exception as e:
        print(path + 'ERR:' + str(e))
        pass

# do the same but without recursion
def traverse_tree_breadth_first(max_depth, head_node):
    curr_nodes = [head_node]
    for d in range(max_depth):
        print('Depth:', d)
        next_nodes = []
        for node in curr_nodes:
            try:
                if node.getFullPath() in SEG_FAULT_NODES: continue # skip the nodes that cause segfault
                preprint = COLORS[d] + "   " * d + node.node_name + ENDC
                print(f'{preprint}:{node.decompile()}') # print the node
                total_nodes_breadth.append(node) # add the node to the list
                # get the usage/type of the node
                try: usage_breadth[str(node.usage)] += 1
                except: usage_breadth[str(node.usage)] = 1
                # get the children of the node
                for child in node.getChildren():
                    next_nodes.append(child)
                # get the members of the node
                for member in node.getMembers():
                    next_nodes.append(member)
            except: pass
        curr_nodes = next_nodes
        
# get the top node of the tree
head_node = rfx.getNode('\\TOP.RFX.MHD') # get the top node
# # test the functions, uncomment to run
traverse_tree_depth_first(MAX_DEPTH, head_node) # traverse the tree depth-first
traverse_tree_breadth_first(MAX_DEPTH, head_node) # traverse the tree breadth-first

print(f'Total nodes depth: {len(total_nodes_depth)}') # 96771, 96750
print(f'Total nodes breadth: {len(total_nodes_breadth)}') # 96771, 96750

In [None]:
print(f'Usage depth: {usage_depth}')
print(f'Usage breadth: {usage_breadth}')

previous cell full depth: 'STRUCTURE': 8776, 'SUBTREE': 78, 'DEVICE': 642, 'ACTION': 1098, 'NUMERIC': 47760, 'TEXT': 17269, 'SIGNAL': 20904, 'ANY': 29, 'AXIS': 215

In [None]:
print(f'top nodes: {[n.node_name for n in head_node.getChildren()]}')

## Exploring Signals

In [None]:
search_space = '\\TOP.RFX.MHD.***' # *** means all nodes at this level
# search_space = '\\TOP.RFX.EDA.***' # * means all nodes at this level
# search_space = '\\TOP.RFX.***' # whole rfx tree
signal_nodes = rfx.getNodeWild(search_space, 'Signal') # get all nodes with the name 'Signal'
print(f'Found {len(signal_nodes)} of the type Signal in the search space {search_space}')

In [None]:
# filter out the nodes without the data
data_signals = []
for node in tqdm(signal_nodes, leave=False):
    try: data = node.data(); data_signals.append(node)
    except: pass
print(f'Found {len(data_signals)}/{len(signal_nodes)} signals with data')

In [None]:
# keep only the signals with raw data
raw_signals = []
for node in tqdm(signal_nodes, leave=False):
    try: data = node.raw_of().data(); raw_signals.append(node)
    except: pass
print(f'Found {len(raw_signals)}/{len(data_signals)} signals with raw data')

In [None]:
# extract data from the signals and plot them
MAX_LOAD = 3 #10 #np.inf
MAX_LOAD = min(MAX_LOAD, len(raw_signals))
# select MAX_LOAD random signals
signals = random.sample(raw_signals, MAX_LOAD)
for node in (signals):
    signal = node.data()
    times = node.dim_of().data()
    unit = node.getUnits()
    full_path = node.getFullPath()
    try: node_help = node.getHelp()
    except: node_help = ''
    if signal.shape != times.shape:
        print(f'{full_path} has mismatched signal and time shapes')
        continue
    # plot the signal
    plt.figure()
    plt.plot(times, signal)
    plt.title(f'{full_path} [{unit}]\n{node_help}')
    plt.xlabel('Time [s]')
    plt.ylabel('Signal')
    plt.show()

## Exploring Text

In [None]:
text_nodes = rfx.getNodeWild(search_space, 'Text') # get all the 'TEXT' nodes
print(f'Found {len(text_nodes)} of the type Text in the search space {search_space}')
# print all the text nodes
for node in text_nodes:
    try: print(f'{node.getFullPath()}={node.data()}')
    except: pass

## Exploring Times

In [None]:
import MDSplus as mds
import numpy as np
import matplotlib.pyplot as plt
import sys
from tqdm import tqdm
print(f'Python version: {sys.version}')
print(f'MDSplus version: {mds.__version__}')
np.set_printoptions(precision=8, suppress=True)

In [None]:
# define the shot number and tree
SHOT = 30810
rfx = mds.Tree('rfx', SHOT, 'readonly') # open the tree read-only
from convert_to_hdf5 import SEG_FAULT_NODES
MAX_DEPTH = 13 # maximum depth of the tree to traverse

In [None]:
# function to understand if the sampling time is constant, if it's not, it keeps the section with
# the lower sampling time
def plot_signal(t,d,dts):
    fig, ax1 = plt.subplots()
    color = 'tab:red'
    ax1.set_xlabel('Time [s]')
    ax1.set_ylabel('Data', color=color)
    ax1.plot(t, d, color=color)
    ax1.tick_params(axis='y', labelcolor=color)
    ax2 = ax1.twinx()
    color = 'tab:blue'
    ax2.set_ylabel('Time differences', color=color)
    ax2.scatter(t[1:], dts, color=color, s=1)
    ax2.tick_params(axis='y', labelcolor=color)
    fig.tight_layout()
    plt.show()
    
def check_and_keep_low_sampling_time(t, data, precision=5):
    dts = np.diff(t) # get the time differences
    dtsr = np.round(dts, precision) # round the time difference to the nearest 1e-6
    unique_dtsr = np.unique(dtsr) # find the unique time differences
    if not np.all(unique_dtsr >= 0):
        plot_signal(t,data,dtsr)
        raise ValueError(f'Negative time differences: {unique_dtsr}')
    elif not np.all(unique_dtsr >= 1e-9):
        plot_signal(t,data,dtsr)
        raise ValueError(f'Very small time differences: {[f"{d:.2e}" for d in unique_dtsr if d < 1e-9]}, len(t): {len(t)}')
    if len(unique_dtsr) == 1: # if the time differences are the same
        it = t # input time
        ridxs = np.arange(len(it)) # return indexes
    elif len(unique_dtsr) == 2:
        min_dt = np.min(unique_dtsr) # use the smallest dt
        ridxs = np.where(dtsr == min_dt)[0]
        # check that all the idxs are adjacent to each other
        if not np.all(np.diff(ridxs) == 1):
            plot_signal(t,data,dtsr)
            raise ValueError(f'Non-adjacent indexes: {ridxs}')
        it = t[ridxs] # input time
    else: 
        #plot a single figure with data against time and dtsr against time on the same y-axis
        plot_signal(t,data,dtsr)
        raise ValueError(f'Invalid time differences: {unique_dtsr}')
    rt = np.linspace(it[0], it[-1], len(it)) # return time
    std_t = np.std(np.diff(rt))
    assert std_t < 1e-8, f'Non-constant time: {std_t}'
    return rt, ridxs

In [None]:
times, datas, names = [], [], []

def traverse_get_times(max_depth, head_node):
    curr_nodes = [head_node]
    for d in range(max_depth):
        print(f'Depth: {d}')
        next_nodes = []
        # for node in tqdm(curr_nodes, ncols=80, desc=f'Depth:{d}'):
        for node in curr_nodes:
            if node.getFullPath() in SEG_FAULT_NODES: continue # skip the nodes that cause segfault
            try: #get the time vector of the node
                timev = node.dim_of().data()
                datav = node.data()
                assert len(timev) == len(datav), f'len(timev)={len(timev)} != len(datav)={len(datav)}'
                try:
                    # assert timev.ndim == 1, f'ignored: timev.ndim={timev.ndim}'
                    # assert len(timev) > 1000, f'ignored: len(timev)={len(timev)}'
                    assert timev.ndim == 1, ''
                    assert len(timev) > 1000, ''
                    times.append(timev), datas.append(datav), names.append(node.getFullPath())
                    t, idxs = check_and_keep_low_sampling_time(timev, datav)
                except Exception as e:
                    if str(e) != '': print(f'{node.getFullPath()} ERR: {e}')
                    pass
            except: pass
            try: next_nodes.extend(node.getChildren()) 
            except: pass
            try: next_nodes.extend(node.getMembers())
            except: pass
        curr_nodes = next_nodes

# traverse_get_times(MAX_DEPTH, rfx.getNode('\\TOP.RFX')) # get the top node)
traverse_get_times(5, rfx.getNode('\\TOP.RFX')) # get the top node)

In [None]:
# analyze the times
lengths = np.array([len(t) for t in times]) # get the lengths of the time vectors
times_diff = [np.diff(t) for t in times] # get the differences between the time vectors
times_diff_stds = [np.std(d) for d in times_diff] # get the std of the differences
stds = np.array(times_diff_stds)

assert len(datas) == len(times), f'len(datas)={len(datas)} != len(times)={len(times)}'
for d, t in zip(datas, times):
    assert len(t) == len(d), f'len(t)={len(t)} != len(d)={len(d)}'
    assert t.shape[0] == d.shape[0], f't.shape[0]={t.shape[0]} != d.shape[0]={d.shape[0]}'

mean_stds = np.mean(stds)
std_stds = np.std(stds)
min_stds, max_stds = np.min(stds), np.max(stds)
mean_lengths = np.mean(lengths)
std_lengths = np.std(lengths)
min_lengths, max_lengths = np.min(lengths), np.max(lengths)

print(f'Mean stds: {mean_stds}, std stds: {std_stds}, min stds: {min_stds}, max stds: {max_stds}')
print(f'Mean lengths: {mean_lengths}, std lengths: {std_lengths}, min lengths: {min_lengths}, max lengths: {max_lengths}')

In [None]:
plt.figure(figsize=(10, 6))
plt.grid(True, linestyle='--', alpha=0.5)
plt.hist(stds, bins=np.logspace(-6, -1, 20), edgecolor='black')
plt.xscale('log')
plt.yscale('log')
plt.xlabel('Standard deviation of the time differences')
plt.ylabel('Counts')
plt.title('Histogram of the standard deviation of the time differences')
plt.tight_layout()
plt.show()

In [None]:
# pick bad times: high std of diff
# %matplotlib widget
%matplotlib inline
bad_times, bad_datas = [], []
for t, d, std in zip(times, datas, times_diff_stds):
    if std > 1e-6 and len(t)>100000: bad_times.append(t), bad_datas.append(d)
print(f'Found {len(bad_times)} bad times')
#select N random indices
N = 5
indices = np.random.choice(len(bad_times), N)
for i in indices:
    t,d = bad_times[i], bad_datas[i] # get the time and data vectors
    full_path = names[i] # get the full path of the node
    print(f'{full_path}\n, t.shape: {t.shape}, d.shape: {d.shape}')
    diff = np.abs(np.diff(t)) # get the differences between the time vector, the deltas
    diff = np.append(diff, 0) # add
    
    # #plot histogram if diff
    # plt.figure(figsize=(6, 3))
    # bins = np.logspace(-6, -1, 20)
    # counts, bins, _ = plt.hist(diff, bins =bins, edgecolor='black')
    # plt.ylabel('Counts')
    # plt.title('Histogram of the time differences')
    # plt.grid(True, linestyle='--', alpha=0.5)
    # for i, count in enumerate(counts):
    #     plt.text(bins[i], count, str(int(count)), ha='center', va='bottom')
    # plt.xscale('log')

    #plot the data vector and time difference
    fig, ax1 = plt.subplots(figsize=(5, 3))
    ax2 = ax1.twinx()
    # plot data vector
    ax1.plot(t, d, color='blue')
    ax1.set_xlabel('Time')
    ax1.set_ylabel('Data', color='blue')
    # plot time difference
    ax2.plot(t, diff, color='red')
    ax2.set_ylabel('Time Difference', color='red')
    ax1.set_title(full_path)
    plt.grid(True, linestyle='--', alpha=0.5)
    
    plt.show()

In [None]:
# get all times starts and ends
tstarts = np.array([t[0] for t in times])
tends = np.array([t[-1] for t in times])
tlengths = np.array([len(t) for t in times])
tdeltas = np.array([np.median(np.diff(t)) for t in times])
tdurations = tends - tstarts

# filter the signals in standard and strange:
strange_idxs, reasons, ok_idxs = [], [], []
for i, (t, d, n, ts, te, tl, td, tdur) in enumerate(zip(times, datas, names, tstarts, tends, tlengths, tdeltas, tdurations)):
        r = []
        if te > 30: r.append('end > 30')
        elif td > 1e-2: r.append('delta > 1e-2'); print(f'{n} delta > 1e-2')
        if tl < 1000: r.append('l < 1000'); print(f'{n} l < 1000')
        if tdur > 60: r.append('dur > 60')
        if len(r) > 0: strange_idxs.append(i), reasons.append(r)
        else: ok_idxs.append(i)
print(f'Found {len(strange_idxs)}/{len(times)} strange signals')

strange_idxs_picked = np.random.choice(strange_idxs, 20)
# strange_idxs_picked = strange_idxs
for j, si in enumerate(strange_idxs_picked):
    d, t, n = datas[si], times[si], names[si]
    reas = reasons[j]
    #plot the signal
    plt.figure(figsize=(5, 2))
    plt.plot(t, d)
    plt.title(f'{n}\n{reas}')
    plt.xlabel('Time [s]')
    plt.ylabel('Signal')
    plt.show()

tstarts, tends, tlengths, tdeltas, tdurations = tstarts[ok_idxs], tends[ok_idxs], tlengths[ok_idxs], tdeltas[ok_idxs], tdurations[ok_idxs]
print(f'Found {len(tstarts)} good signals')

nbins = 20
# create a histogram for each of the time properties
fig, axs = plt.subplots(5, 1, figsize=(10, 12))
counts, bins, _ = axs[0].hist(tstarts, bins=nbins, edgecolor='black')
for i, count in enumerate(counts):
    axs[0].text(bins[i], count, str(int(count)), ha='center', va='bottom')
axs[0].set_xlabel('Time Start')
axs[0].set_ylabel('Counts')
axs[0].grid(True, linestyle='--', alpha=0.5)
counts, bins, _ = axs[1].hist(tends, bins=nbins, edgecolor='black') 
for i, count in enumerate(counts):
    axs[1].text(bins[i], count, str(int(count)), ha='center', va='bottom')
axs[1].set_xlabel('Time End')
axs[1].grid(True, linestyle='--', alpha=0.5)
counts, bins, _ = axs[2].hist(tlengths, bins=nbins, edgecolor='black')
for i, count in enumerate(counts):
    axs[2].text(bins[i], count, str(int(count)), ha='center', va='bottom')
axs[2].set_xlabel('Time Length')
axs[2].grid(True, linestyle='--', alpha=0.5)
counts, bins, _ = axs[3].hist(tdeltas, bins=nbins, edgecolor='black')
for i, count in enumerate(counts):
    axs[3].text(bins[i], count, str(int(count)), ha='center', va='bottom')
axs[3].set_xlabel('Time Delta')
axs[3].grid(True, linestyle='--', alpha=0.5)
counts, bins, _ = axs[4].hist(tdurations, bins=nbins, edgecolor='black')
for i, count in enumerate(counts):
    axs[4].text(bins[i], count, str(int(count)), ha='center', va='bottom')
axs[4].set_xlabel('Time Duration')
axs[4].grid(True, linestyle='--', alpha=0.5)
plt.tight_layout()
plt.show()

## Eplore RAW Signals

In [1]:
import MDSplus as mds
import numpy as np
import matplotlib.pyplot as plt
import sys
from tqdm import tqdm
print(f'Python version: {sys.version}')
print(f'MDSplus version: {mds.__version__}')
np.set_printoptions(precision=8, suppress=True)

# define the shot number and tree
SHOT = 30810
rfx = mds.Tree('rfx', SHOT, 'readonly') # open the tree read-only
from convert_to_hdf5 import SEG_FAULT_NODES
MAX_DEPTH = 13 # maximum depth of the tree to traverse

Python version: 3.7.12 | packaged by conda-forge | (default, Oct 26 2021, 06:08:53) 
[GCC 9.4.0]
MDSplus version: 1.0.0


In [3]:
raw_signals = []

def traverse_get_raw(max_depth, head_node):
    curr_nodes = [head_node]
    for d in range(max_depth):
        next_nodes = []
        for node in tqdm(curr_nodes, ncols=80, desc=f'Depth:{d}'):
        # for node in curr_nodes:
            # get the children and members of the node
            try: next_nodes.extend(node.getChildren()) 
            except: pass
            try: next_nodes.extend(node.getMembers())
            except: pass
            full_path = node.getFullPath()
            if full_path in SEG_FAULT_NODES: continue # skip the nodes that cause segfault
            if not ('RAW' in full_path.upper()): continue # skip the nodes that are not raw
            try: #get the data and time of the node
                timev = node.dim_of().data()
                datav = node.data()
                assert len(timev) == len(datav), f'len(timev)={len(timev)} != len(datav)={len(datav)}'
                assert timev.ndim == 1, f'ignored: timev.ndim={timev.ndim}'
                assert len(timev) > 1000, f'ignored: len(timev)={len(timev)}'
            except: pass
        curr_nodes = next_nodes

traverse_get_raw(5, rfx.getNode('\\TOP.RFX')) # get the top node)

print(f'Found {len(raw_signals)} raw signals')

Depth: 0


Depth:0: 100%|████████████████████████████████████| 1/1 [00:00<00:00, 65.78it/s]


Depth: 1


Depth:1: 100%|████████████████████████████████████| 8/8 [00:00<00:00, 55.22it/s]


Depth: 2


Depth:2: 100%|██████████████████████████████████| 88/88 [00:01<00:00, 83.02it/s]


Depth: 3


Depth:3: 100%|███████████████████████████████| 623/623 [00:06<00:00, 102.79it/s]


Depth: 4


Depth:4: 100%|█████████████████████████████| 2422/2422 [00:21<00:00, 114.19it/s]

Found 0 raw signals



