In [None]:
!pip install MRzeroCore
!pip install toolapi

In [None]:
# @title Download sequence and phantom data

!wget -q --show-progress https://github.com/mrx-org/toolapi-py/releases/download/v0.1.0/gre.seq
!wget -q --show-progress https://github.com/mrx-org/toolapi-py/releases/download/v0.1.0/brainweb-subj05-3T.json
!wget -q --show-progress https://github.com/mrx-org/toolapi-py/releases/download/v0.1.0/brainweb-subj05.nii.gz
!wget -q --show-progress https://github.com/mrx-org/toolapi-py/releases/download/v0.1.0/brainweb-subj05_dB0.nii.gz
!wget -q --show-progress https://github.com/mrx-org/toolapi-py/releases/download/v0.1.0/brainweb-subj05_B1+.nii.gz

In [None]:
import MRzeroCore as mr0
import torch
import matplotlib.pyplot as plt
from time import time

import toolapi

In [None]:
# @title Helpers to convert between `toolapi` and `MRzeroCore`
def phantom_dict_to_toolapi(phantom: mr0.TissueDict) -> toolapi.value.MultiTissuePhantom:
    """Convert an MRzero tissue dictionary into a toolapi multi-tissue phantom.

    Every tissue keeps its per-voxel PD and B0 maps (flattened to lists), while
    T1, T2, T2' and D are collapsed to one scalar per tissue: the mean over the
    voxels where that tissue's PD > 0. Voxel size, grid size, B1 and coil
    sensitivities are taken from the first tissue only. The code assumes all
    tissues share them.

    Args:
        phantom: Tissue dictionary, e.g. from ``mr0.TissueDict.load``.

    Returns:
        The equivalent ``toolapi.value.MultiTissuePhantom``.

    Raises:
        ValueError: If ``phantom`` contains no tissues.
    """
    tissues = list(phantom.values())
    if not tissues:
        raise ValueError("phantom_dict_to_toolapi: phantom contains no tissues")
    ref = tissues[0]
    # Physical extent divided by the number of voxels along each axis
    voxel_size = (ref.size / torch.as_tensor(ref.PD.shape)).tolist()

    def tissue_entry(tissue):
        # Average relaxation/diffusion only over voxels that contain this tissue,
        # so PD == 0 background does not dilute the values.
        mask = tissue.PD > 0
        if not mask.any():
            # A tissue can be absent from the kept region (e.g. after slicing).
            # An empty mean would be NaN, so fall back to all voxels. Its PD is
            # zero everywhere anyway, so the exact value does not affect the signal.
            mask = torch.ones_like(mask)
        return (
            tissue.PD.flatten().tolist(),
            tissue.B0.flatten().tolist(),
            toolapi.value.TissueProperties(
                float(tissue.T1[mask].mean()),
                float(tissue.T2[mask].mean()),
                float(tissue.T2dash[mask].mean()),
                float(tissue.D[mask].mean()),
            ),
        )

    return toolapi.value.MultiTissuePhantom(
        "AASinc",  # voxel shape type
        voxel_size,  # voxel shape data
        voxel_size,  # grid spacing
        list(ref.PD.shape),  # grid size
        # We assume that all tissues have the same B1 and coil_sens
        ref.B1.reshape(ref.B1.shape[0], -1).tolist(),
        ref.coil_sens.reshape(ref.coil_sens.shape[0], -1).tolist(),
        [tissue_entry(tissue) for tissue in tissues],
    )


def to_instant_events(seq: mr0.Sequence) -> toolapi.value.EventSeq:
    """Flatten an MRzero sequence into a toolapi instant-event sequence.

    For every repetition this emits one ``Pulse`` (angle, phase) event. Then,
    for each of the repetition's events, it emits a ``Fid`` event carrying
    ``[gradm_x, gradm_y, gradm_z, event_time]``. That ``Fid`` is followed by an
    ``Adc`` event whenever the event's ``adc_usage`` is non-zero.
    """
    Event = toolapi.value.Event
    events = []
    for rep in seq:
        events.append(Event.Pulse(rep.pulse.angle, rep.pulse.phase))
        for i in range(rep.event_count):
            gx, gy, gz = rep.gradm[i, 0], rep.gradm[i, 1], rep.gradm[i, 2]
            events.append(Event.Fid([gx, gy, gz, rep.event_time[i]]))
            if rep.adc_usage[i] != 0:
                # NOTE(review): pi/2 - phase presumably converts MRzero's ADC
                # phase convention to toolapi's. TODO confirm.
                events.append(Event.Adc(torch.pi / 2 - rep.adc_phase[i]))
    return toolapi.value.EventSeq(events)

In [None]:
# @title Define tools
def on_message(msg):
    """Progress callback for ``toolapi.call``.

    Overwrites the current console line with ``msg`` and always returns True.
    Presumably True tells toolapi to continue the call. TODO confirm.
    """
    line = f"\r > {msg}"
    print(line, end="")
    return True

def load_seq(path, exact_trajectories=True):
    """Parse a sequence file with the remote seqloader tool.

    The file is read locally and its full text is sent to the tool.

    Args:
        path: Path of the sequence file to upload.
        exact_trajectories: Forwarded to the tool as ``exact_trajectory``.

    Returns:
        The ``"seq"`` entry of the tool's result.
    """
    with open(path) as seq_file:
        seq_text = seq_file.read()
    response = toolapi.call(
        "wss://tool-seqloader-flyio.fly.dev/tool",
        on_message,
        seq_file=seq_text,
        exact_trajectory=exact_trajectories
    )
    # on_message leaves the cursor on the progress line; move past it
    print("\n --- done ---")
    return response["seq"]

# Define the simulation tool
def sim_spdg(sequence, phantom):
    """Simulate ``sequence`` on ``phantom`` with the remote spdg tool.

    Progress messages are streamed through ``on_message``.

    Returns:
        The ``"signal"`` entry of the tool's result.
    """
    tool_url = "wss://tool-spdg-flyio.fly.dev/tool"
    response = toolapi.call(tool_url, on_message, sequence=sequence, phantom=phantom)
    # on_message leaves the cursor on the progress line; move past it
    print("\n --- done ---")
    return response["signal"]

In [None]:
# @title Load phantom
# The sequence is loaded in the next cell. It is not loaded here: a second,
# inconsistent load (exact_trajectories=False for mr0 only) would just be overwritten.

# Colab can't handle the phantom at full resolution, so reduce the data:
# drop the fat tissue, use constant B1 values, and keep one slice of a 64^3 grid.
PHANTOM_RES = 64  # interpolation size along each of the three axes
PHANTOM_SLICE = 30  # index of the slice kept after interpolation

config = mr0.NiftiPhantom.load("brainweb-subj05-3T.json")
config.tissues.pop("fat")
for tissue in config.tissues.values():
    # NOTE(review): presumably replaces the B1 map references from the config
    # by a constant 1.0. TODO confirm against MRzeroCore's NiftiPhantom docs.
    tissue.B1_tx = [1.0]
    tissue.B1_rx = [1.0]

phantom_mr0 = mr0.TissueDict.load(".", config)
phantom_mr0 = phantom_mr0.interpolate(PHANTOM_RES, PHANTOM_RES, PHANTOM_RES).slices([PHANTOM_SLICE])
# `phantom` is the toolapi representation consumed by sim_spdg below
phantom = phantom_dict_to_toolapi(phantom_mr0)

In [None]:
# @title Load sequence
# Parse gre.seq twice, so the two results can be compared below:
# - remotely, with the seqloader tool;
# - locally, with MRzeroCore, then converted to toolapi events.
EXACT_TRAJECTORIES = True
seq_tool = load_seq("gre.seq", exact_trajectories=EXACT_TRAJECTORIES)
seq_mr0 = to_instant_events(
    mr0.Sequence.import_file("gre.seq", exact_trajectories=EXACT_TRAJECTORIES)
)

In [None]:
# @title compare parsed sequences for differences
# Both parsers should produce the same instant-event stream; compare event by event.
print(f"event count: tool={len(seq_tool.events)}, mr0={len(seq_mr0.events)}")

# (tool value, mr0 value) pairs for every compared quantity
event_diffs = {name: [] for name in ["angle", "phase", "kt_x", "kt_y", "kt_z", "kt_t", "adc phase"]}
KT_NAMES = ["kt_x", "kt_y", "kt_z", "kt_t"]  # order of the entries in a Fid's `kt` field

for i, (ev_tool, ev_mr0) in enumerate(zip(seq_tool.events, seq_mr0.events)):
    if ev_tool.variant != ev_mr0.variant:
        # The streams are out of sync from here on; pairing further events would
        # compare unrelated quantities (or fail on missing fields).
        print(f"event {i}: variant mismatch (tool={ev_tool.variant}, mr0={ev_mr0.variant}), stopping comparison")
        break
    if ev_tool.variant == 'Pulse':
        event_diffs["angle"].append((ev_tool.fields['angle'], ev_mr0.fields['angle']))
        event_diffs["phase"].append((ev_tool.fields['phase'], ev_mr0.fields['phase']))
    elif ev_tool.variant == 'Fid':
        for axis, name in enumerate(KT_NAMES):
            event_diffs[name].append((ev_tool.fields['kt'][axis], ev_mr0.fields['kt'][axis]))
    elif ev_tool.variant == 'Adc':
        event_diffs["adc phase"].append((ev_tool.fields['phase'], ev_mr0.fields['phase']))

for name, pairs in event_diffs.items():
    # default=None: a quantity without any events (e.g. no ADC) would make max() raise
    max_diff = max((abs(a - b) for a, b in pairs), default=None)
    print(f"max diff in {name}: {max_diff}")

PLOT_EVENT_DIFFS = False  # set True to plot the per-event differences
if PLOT_EVENT_DIFFS:
    plt.figure()
    for name in ["angle", "phase", *KT_NAMES]:
        plt.plot([a - b for a, b in event_diffs[name]], label=name)
    plt.legend()
    plt.grid()
    plt.show()

In [None]:
# @title Run simulation on fly.io
# Both event streams are simulated on the same phantom, so differences between
# the reconstructions come only from differences between the parsed sequences.
KSPACE_SHAPE = (256, 256)  # grid the 1D ADC signal is reshaped to; must match gre.seq


def reconstruct(signal):
    """Reshape a flat signal to KSPACE_SHAPE and return its centered 2D FFT (complex image)."""
    kspace = signal.reshape(KSPACE_SHAPE)
    return torch.fft.fftshift(torch.fft.fft2(torch.fft.fftshift(kspace)))


def simulate(sequence):
    """Simulate `sequence` on `phantom` remotely, print the wall time, return signal row 0."""
    start = time()
    # [0, :]: first row of the returned signal (presumably the first coil — TODO confirm)
    signal = torch.tensor(sim_spdg(sequence, phantom))[0, :]
    # fixed-point format: `:.3` would switch to scientific notation for runs >= 100 s
    print(f"took {time() - start:.2f} s")
    return signal


signal_tool = simulate(seq_tool)
signal_mr0 = simulate(seq_mr0)

reco_tool = reconstruct(signal_tool)
reco_mr0 = reconstruct(signal_mr0)

# Small difference due to small deviations (order of 1e-7) between adc phases
reco_diff = reco_tool - reco_mr0
plt.figure()
plt.suptitle("Difference")
plt.subplot(121)
plt.title("magnitude")
plt.imshow(reco_diff.abs(), origin="lower", vmin=0)
plt.axis("off")
plt.colorbar()
plt.subplot(122)
plt.title("phase")
plt.imshow(reco_diff.angle(), origin="lower", vmin=-torch.pi, vmax=torch.pi, cmap="twilight")
plt.axis("off")
plt.colorbar()
plt.show()

for signal, reco, title in [(signal_tool, reco_tool, "Tool"), (signal_mr0, reco_mr0, "MR0")]:
    plt.figure()
    plt.suptitle(title)
    plt.subplot(211)
    plt.plot(signal.abs())
    plt.xlabel("ADC sample")
    plt.ylabel("|signal|")
    plt.grid()
    plt.subplot(223)
    plt.title("magnitude")
    plt.imshow(reco.abs(), origin="lower", vmin=0)
    plt.axis("off")
    plt.subplot(224)
    plt.title("phase")
    plt.imshow(reco.angle(), origin="lower", vmin=-torch.pi, vmax=torch.pi, cmap="twilight")
    plt.axis("off")
    plt.show()
