In [None]:
!pip install MRzeroCore
!pip install toolapi

In [None]:
# @title Download sequence and phantom data

!wget -q --show-progress https://github.com/mrx-org/toolapi-py/releases/download/v0.1.0/gre.seq
!wget -q --show-progress https://github.com/mrx-org/toolapi-py/releases/download/v0.1.0/brainweb-subj05-3T.json
!wget -q --show-progress https://github.com/mrx-org/toolapi-py/releases/download/v0.1.0/brainweb-subj05.nii.gz
!wget -q --show-progress https://github.com/mrx-org/toolapi-py/releases/download/v0.1.0/brainweb-subj05_dB0.nii.gz
!wget -q --show-progress https://github.com/mrx-org/toolapi-py/releases/download/v0.1.0/brainweb-subj05_B1+.nii.gz

In [None]:
import MRzeroCore as mr0
import torch
import matplotlib.pyplot as plt
from time import time

import toolapi

In [None]:
# @title Helpers to convert between `toolapi` and `MRzeroCore`
def phantom_dict_to_toolapi(phantom: mr0.TissueDict) -> toolapi.value.MultiTissuePhantom:
    """Convert an MRzero ``TissueDict`` into a toolapi ``MultiTissuePhantom``.

    Grid geometry, B1 and coil sensitivities are taken from the first tissue
    in ``phantom``; all tissues are assumed to share them. The voxel size is
    ``size / PD.shape`` and is used both as the "AASinc" voxel-shape data and
    as the grid spacing.

    Each tissue contributes its flattened PD and B0 maps plus one
    ``TissueProperties`` built from the means of its T1, T2, T2dash and D maps,
    taken over the voxels where that tissue's PD is positive.

    Tissues without any voxel of positive PD (e.g. a tissue that does not
    appear in the selected slice) are skipped: averaging over an empty mask
    gives NaN, which would be sent to the simulator, and such a tissue
    contributes no signal anyway.

    Raises:
        ValueError: if ``phantom`` is empty or no tissue has positive PD.
    """
    all_tissues = list(phantom.values())
    if not all_tissues:
        raise ValueError("phantom contains no tissues")

    reference = all_tissues[0]
    voxel_size = (reference.size / torch.as_tensor(reference.PD.shape)).tolist()

    tissue_data = []
    for tissue in all_tissues:
        mask = tissue.PD > 0  # computed once, reused for all four averages
        if not bool(mask.any()):
            continue  # mean over an empty mask would be NaN
        properties = toolapi.value.TissueProperties(
            float(tissue.T1[mask].mean()),
            float(tissue.T2[mask].mean()),
            float(tissue.T2dash[mask].mean()),
            float(tissue.D[mask].mean()),
        )
        tissue_data.append(
            (tissue.PD.flatten().tolist(), tissue.B0.flatten().tolist(), properties)
        )

    if not tissue_data:
        raise ValueError("no tissue in phantom has positive proton density")

    return toolapi.value.MultiTissuePhantom(
        "AASinc",  # voxel shape type
        voxel_size,  # voxel shape data
        voxel_size,  # grid spacing
        list(reference.PD.shape),  # grid size
        # We assume that all tissues have the same B1 and coil_sens
        reference.B1.reshape(reference.B1.shape[0], -1).tolist(),
        reference.coil_sens.reshape(reference.coil_sens.shape[0], -1).tolist(),
        tissue_data,
    )


def to_instant_events(seq: mr0.Sequence) -> toolapi.value.EventSeq:
    """Flatten an MRzero sequence into a toolapi ``EventSeq`` of instant events.

    For every repetition, emit one ``Pulse(angle, phase)`` from ``rep.pulse``.
    Then, for each of its ``event_count`` events, emit:
      * a ``Fid`` whose payload is ``[gradm[ev, 0], gradm[ev, 1], gradm[ev, 2],
        event_time[ev]]``, and
      * directly after it, an ``Adc`` with phase ``pi/2 - adc_phase[ev]``, but
        only if ``adc_usage[ev]`` is non-zero.
    """
    event_types = toolapi.value.Event
    events = []

    for repetition in seq:
        pulse = repetition.pulse
        events.append(event_types.Pulse(pulse.angle, pulse.phase))

        for idx in range(repetition.event_count):
            payload = [repetition.gradm[idx, axis] for axis in range(3)]
            payload.append(repetition.event_time[idx])
            events.append(event_types.Fid(payload))

            if repetition.adc_usage[idx] != 0:
                events.append(event_types.Adc(torch.pi / 2 - repetition.adc_phase[idx]))

    return toolapi.value.EventSeq(events)

In [None]:
# @title Load sequence and phantom
# Colab can't handle the phantom at full resolution; reduce data a bit
PHANTOM_RES = 64    # phantom is interpolated to PHANTOM_RES^3 voxels
PHANTOM_SLICE = 30  # index of the single slice kept for the simulation

# Stage-specific names: `seq` and `phantom` are the toolapi objects used by the
# simulation cell; the MRzero intermediates keep their own names instead of
# being rebound to objects of a different type.
seq_mr0 = mr0.Sequence.import_file("gre.seq")
seq = to_instant_events(seq_mr0)

config = mr0.NiftiPhantom.load("brainweb-subj05-3T.json")
config.tissues.pop("fat")
# Replace the B1 transmit/receive maps with the constant 1.0
for tissue in config.tissues.values():
    tissue.B1_tx = [1.0]
    tissue.B1_rx = [1.0]

# Chained so the full-resolution phantom is never bound to a global and can be
# freed immediately (memory is the reason for the reduction in the first place).
phantom_slice = (
    mr0.TissueDict.load(".", config)
    .interpolate(PHANTOM_RES, PHANTOM_RES, PHANTOM_RES)
    .slices([PHANTOM_SLICE])
)
phantom = phantom_dict_to_toolapi(phantom_slice)

In [None]:
# @title Run simulation on fly.io

# Define the simulation tool
def sim_spdg(sequence, phantom, url="wss://tool-spdg-flyio.fly.dev/tool"):
    """Run the remote SPDG simulation tool via ``toolapi.call``.

    Args:
        sequence: passed to the tool as ``sequence``; in this notebook this is
            the ``EventSeq`` produced by ``to_instant_events``.
        phantom: passed to the tool as ``phantom``; in this notebook this is
            the ``MultiTissuePhantom`` from ``phantom_dict_to_toolapi``.
        url: WebSocket endpoint of the tool. Defaults to the fly.io deployment.

    Returns:
        Whatever ``toolapi.call`` returns; this notebook reads its ``"signal"``
        entry.
    """
    def on_message(msg):
        # `\r` overwrites the current line with the latest progress message.
        # flush=True makes it appear immediately even though no newline is printed.
        print(f"\r > {msg}", end="", flush=True)
        # NOTE(review): presumably returning True tells toolapi to continue — confirm.
        return True

    return toolapi.call(
        url,
        on_message,
        sequence=sequence,
        phantom=phantom,
    )

# %% Run the simulation!
start = time()
print("Run on fly.io")
# Index 0 of the first axis of the returned signal is kept
# (presumably the first receive coil — TODO confirm).
signal = torch.tensor(sim_spdg(seq, phantom)["signal"])[0, :]
end = time()

# `.2f` rather than a bare `.3`: without a type, precision 3 switches to
# scientific notation for runs of 100 s or more (e.g. "1.23e+02").
print(f"\n\ndone - took {end - start:.2f} s")

# k-space matrix acquired by gre.seq — TODO: derive from the sequence instead of hardcoding
KSPACE_SHAPE = (256, 256)
assert signal.numel() == KSPACE_SHAPE[0] * KSPACE_SHAPE[1], (
    f"expected {KSPACE_SHAPE[0] * KSPACE_SHAPE[1]} ADC samples, got {signal.numel()}"
)
kspace = signal.reshape(KSPACE_SHAPE)
# NOTE(review): uses the forward fft2 (not ifft2), presumably to match the
# simulator's signal sign convention — confirm before changing.
reco = torch.fft.fftshift(torch.fft.fft2(torch.fft.fftshift(kspace)))

fig = plt.figure()

ax_signal = fig.add_subplot(211)
ax_signal.plot(signal.abs())
ax_signal.set(title="ADC signal magnitude", xlabel="sample index", ylabel="|signal|")
ax_signal.grid()

ax_mag = fig.add_subplot(223)
ax_mag.imshow(reco.abs(), origin="lower", vmin=0)
ax_mag.set_title("Reconstruction: magnitude")
ax_mag.axis("off")

ax_phase = fig.add_subplot(224)
ax_phase.imshow(reco.angle(), origin="lower", vmin=-torch.pi, vmax=torch.pi, cmap="twilight")
ax_phase.set_title("Reconstruction: phase")
ax_phase.axis("off")

fig.tight_layout()
plt.show()
