In [None]:
import time
import mosaic
from mosaic import tessera

@tessera
class Solver1:
    """Toy solver that accumulates incoming data onto its stored state.

    Decorated with ``tessera``; the notebook instantiates it through
    ``Solver1.remote(...)``.
    """

    def __init__(self, data):
        """Store ``data`` as the initial state."""
        self.data = data

    def solve(self, data):
        """Add ``data`` to the stored state and return the updated state.

        The 10 second sleep stands in for an expensive computation.
        """
        # Imported locally on purpose: a later cell rebinds the notebook-level
        # name ``time`` to a stride ``Time`` object, which would break
        # ``time.sleep`` if this class were used after that cell has run.
        import time

        print('Solve 1')
        self.data = self.data + data

        time.sleep(10)
        print('Done 1')

        return self.data

    def solve_more(self):
        """Simulate 5 seconds of additional work; returns ``None``."""
        # Local import for the same reason as in ``solve``.
        import time

        print('Solve More 1')
        time.sleep(5)
        print('Done More 1')


@tessera
class Solver2:
    """Toy solver that stores twice the data it is given.

    Decorated with ``tessera``; the notebook instantiates it through
    ``Solver2.remote()``.
    """

    def __init__(self):
        """Start with a state of ``0``."""
        self.data = 0

    def solve(self, data):
        """Replace the stored state with ``data*2`` and return it.

        The 10 second sleep stands in for an expensive computation.
        """
        # Imported locally on purpose: a later cell rebinds the notebook-level
        # name ``time`` to a stride ``Time`` object, which would break
        # ``time.sleep`` if this class were used after that cell has run.
        import time

        print('Solve 2')
        self.data = data*2

        time.sleep(10)
        print('Done 2')

        return self.data

    def solve_more(self):
        """Simulate 5 seconds of additional work; returns ``None``."""
        # Local import for the same reason as in ``solve``.
        import time

        print('Solve More 2')
        time.sleep(5)
        print('Done More 2')

In [None]:

from stride import Operator


class H(Operator):
    """Template operator taking ``x`` and producing ``z``."""

    def forward(self, x):
        """Return the output ``z``, allocated with ``x.alike()``."""
        output = x.alike()
        # The computation of z from x belongs here
        return output

    def adjoint(self, grad_z, x, **kwargs):
        """Return the gradient with respect to ``x``, allocated with ``x.alike()``."""
        gradient = x.alike()
        # The computation of the gradient wrt x belongs here
        return gradient


class G(Operator):
    """Template operator taking ``z`` and producing ``y``."""

    def forward(self, z):
        """Return the output ``y``, allocated with ``z.alike()``."""
        output = z.alike()
        # The computation of y from z belongs here
        return output

    def adjoint(self, grad_y, z, **kwargs):
        """Return the gradient with respect to ``z``, allocated with ``z.alike()``."""
        gradient = z.alike()
        # The computation of the gradient wrt z belongs here
        return gradient


class F(Operator):
    """Template operator taking ``y`` and producing ``w``."""

    def forward(self, y):
        """Return the output ``w``, allocated with ``y.alike()``."""
        output = y.alike()
        # The computation of w from y belongs here
        return output

    def adjoint(self, grad_w, y, **kwargs):
        """Return the gradient with respect to ``y``, allocated with ``y.alike()``."""
        gradient = y.alike()
        # The computation of the gradient wrt y belongs here
        return gradient


# Create one instance of each template operator defined above
h = H()
g = G()
f = F()

In [None]:
# Switch mosaic's interactive mode on, asking for 8 workers
await mosaic.interactive('on', num_workers=8)
# Bare last expression: the cell displays the runtime object
mosaic.runtime()

In [None]:
import numpy as np

# float32 array of ones with shape (1024, 1024, 1), used as solver input
array = np.ones((1024, 1024, 1), dtype=np.float32)

# These objects will be created remotely
# Solver1 starts from `array`; Solver2 takes no constructor arguments
solver_1 = Solver1.remote(array)
solver_2 = Solver2.remote()

In [None]:
# Display the object returned by Solver1.remote(...)
solver_1

In [None]:
# Check the current value of the attribute
# The attribute access is awaited first, then summed locally with numpy
np.sum(await solver_1.data)

In [None]:
# These will run in parallel
# The calls will return immediately by creating a remote task
task_1 = solver_1.solve(array)
task_2 = solver_2.solve(array)
# Bare last expression: the cell displays the task object of solver_1
task_1

In [None]:
# Wait until the remote tasks are finished
# Both tasks were already launched in the previous cell; they are only
# awaited one after the other here
await task_1
await task_2

In [None]:
# The results of the tasks stay in the remote worker
# until we request it back
result_1 = await task_1.result()
result_2 = await task_2.result()

# Print the shape of each retrieved result
print(result_1.shape)
print(result_2.shape)

In [None]:
# Same check as before, now that solve() has updated solver_1's `data`
np.sum(await solver_1.data)

In [None]:
# These will wait for each other because
# their results depend on each other
task_1 = solver_1.solve(array)
# The task object itself (not a retrieved result) is passed as the input
task_2 = solver_2.solve(task_1)

In [None]:
# Wait until the remote tasks are finished
# Now we only need to wait for the second task
# (task_2 was given task_1 as its input in the previous cell)
await task_2

In [None]:
obj = dict(a=1, b=2)

runtime = mosaic.runtime()
ref = await runtime.put(obj)

ref

In [None]:
# Await the value behind the reference and display it
# (expected to be the dict that was put above — TODO confirm)
await ref.value()

In [None]:
# Put the same object again, this time with publish=True; the returned value
# is only displayed by the cell and is not stored in a variable
await runtime.put(obj, publish=True)

In [None]:
# Drop the reference returned by the first put() call
await ref.drop()

In [None]:
async def run():
    """Coroutine wrapper that returns the outcome of ``gpu_test()``.

    NOTE(review): ``gpu_test`` is only defined in the next cell, which also
    defines ``run`` again, so this cell looks redundant.
    """
    return gpu_test()

In [None]:
import torch
import asyncio

# CUDA availability and device count as seen from the notebook process
print(torch.cuda.is_available())
print(torch.cuda.device_count())

# Get the runtime
runtime = mosaic.runtime()

print(f"Number of Mosaic workers: {runtime.num_workers}")
# NOTE(review): `_workers` is an underscore-prefixed attribute — check
# whether the runtime offers a public way to list worker IDs
print(f"Available worker IDs: {list(runtime._workers.keys())}")

# Check GPU availability in the current process (the function is called locally below, not on the workers)
def gpu_test():
    """Run a small matrix product on the GPU and report the outcome as text.

    Returns
    -------
    str
        ``"GPU operation successful"`` if the product completes,
        ``"GPU not available"`` if torch reports no CUDA support, or
        ``"Error during GPU test: <message>"`` if anything raises.
    """
    try:
        if not torch.cuda.is_available():
            return "GPU not available"

        device = torch.device("cuda")
        x = torch.rand(1000, 1000, device=device)
        torch.matmul(x, x)
        # CUDA kernels are launched asynchronously: wait for completion so
        # that a failing kernel is reported here rather than at a later call.
        torch.cuda.synchronize()
        return "GPU operation successful"
    except Exception as e:
        # Deliberately broad: this is a diagnostic and must always report.
        return f"Error during GPU test: {str(e)}"

async def run():
    """Coroutine wrapper that returns the string produced by ``gpu_test()``."""
    return gpu_test()

# Await the async function and print its result
# (run() is awaited directly in the notebook's own event loop)
result = await run()
print(result)

In [None]:
import torch
# Number of CUDA devices visible to this process
gpu_count = torch.cuda.device_count()
print(gpu_count)  # Should output 4
# One line per device with its reported name
for gpu_index in range(gpu_count):
    print(f"GPU {gpu_index}: {torch.cuda.get_device_name(gpu_index)}")


In [None]:
from stride.utils import fetch, wavelets
from stride import Space, Time, Grid
from stride import Problem, ScalarField
import numpy as np
import h5py
import matplotlib.pyplot as plt
%matplotlib widget
import os
# NOTE(review): this is set after `stride` has already been imported above —
# confirm the variable is read lazily rather than at import time
os.environ['STRIDE_BACKEND'] = 'cuda'


# Load the MIDA model
# The location can be overridden with the MIDA_MODEL_PATH environment
# variable; otherwise the original hard-coded path is used.
mida_model_path = os.environ.get('MIDA_MODEL_PATH', '/home/leozaroff/mida_model.h5')

# A dedicated handle name is used so that the operator instance `f` created
# earlier in the notebook is not overwritten by the file object.
with h5py.File(mida_model_path, 'r') as model_file:
    # Copy the 'vp' dataset into memory before the file is closed
    vp_data = np.array(model_file['vp'])
    print(f"Model data shape: {vp_data.shape}")
    print(f"Data range: min={np.min(vp_data)}, max={np.max(vp_data)}")
    print(f"Data type: {vp_data.dtype}")

# Paper settings:
# - Sub-MHz frequencies (100-850 kHz)
# - 500μm isotropic resolution
# - 1024 transducers around the head
spacing = (0.5e-3, 0.5e-3, 0.5e-3)  # 500μm spacing as per paper
# The spatial shape is taken directly from the loaded model array
space = Space(shape=vp_data.shape, 
             extra=(50, 50, 50),     # Extra space for wave propagation
             absorbing=(40, 40, 40),  # Absorbing boundary conditions
             spacing=spacing)

# Time settings based on paper's sub-MHz frequency requirements
dt = 0.08e-6  # Time step
T = 240e-6    # Total time (240 μs as per paper)
# NOTE(review): int() truncates; if T/dt lands just below a whole number in
# floating point this gives one step fewer than intended — consider round()
nt = int(T/dt)
# NOTE(review): this rebinds the name `time`, which the first cell imported as
# the standard-library module and which the Solver classes use for sleep()
time = Time(start=0, step=dt, num=nt)

grid = Grid(space, time)

# Create problem
problem = Problem(name='head3D', space=space, time=time)

# Create transducers - paper uses 1024 transducers
num_locations = 1024  # As specified in paper
problem.transducers.default()

# Set up geometry for transducer placement
# Paper places transducers around the head in 3D, avoiding face
# Semi-axes: half of each domain limit after removing a 30 mm margin
radius = tuple((space.limit[axis] - 30e-3) / 2 for axis in range(3))
# Centre: midpoint of the domain along each axis
centre = tuple(space.limit[axis] / 2 for axis in range(3))

# Create ellipsoidal arrangement avoiding the face area
geometry_options = dict(
    theta=np.pi * 0.75,  # Reduced angle to avoid face
    threshold=0.5,
)
problem.geometry.default('ellipsoidal', num_locations, radius, centre,
                         **geometry_options)

# Setup acquisitions
problem.acquisitions.default()

# Create wavelets - paper uses 100-850 kHz bandwidth
f_centre = 0.4e6  # 400 kHz center frequency
n_cycles = 3      # 3-cycle tone burst as mentioned in paper

# The tone burst does not depend on the shot, so it is generated once and
# copied into every shot instead of being recomputed on each iteration.
tone_burst = wavelets.tone_burst(f_centre, n_cycles, time.num, time.step)

for shot in problem.acquisitions.shots:
    # Slice assignment copies the values into the shot's own buffer
    # (assumes `data` is an array-like that copies on assignment — TODO confirm)
    shot.wavelets.data[0, :] = tone_burst

# Load velocity model
# Wrap the loaded array in a field named 'vp' defined on the grid,
# and register it in the problem's medium
vp_true = ScalarField(name='vp', grid=grid, data=vp_data)
problem.medium.add(vp_true)

# Enhanced visualization
plt.figure(figsize=(15, 5))

# Plot three orthogonal slices: one central slice per array axis
slice_views = (
    ('Sagittal View', vp_data[vp_data.shape[0]//2, :, :]),
    ('Coronal View', vp_data[:, vp_data.shape[1]//2, :]),
    ('Axial View', vp_data[:, :, vp_data.shape[2]//2]),
)

for panel, (view_title, view_image) in enumerate(slice_views, start=1):
    plt.subplot(1, 3, panel)
    plt.imshow(view_image, cmap='viridis')
    plt.title(view_title)
    plt.colorbar(label='Velocity (m/s)')

plt.tight_layout()
plt.show()

# Print configuration details
# Values are converted for display: spacing to mm, frequency to MHz, time to μs
print("\nProblem configuration:")
print(f"Space shape: {space.shape}")
print(f"Spatial resolution: {spacing[0]*1e3} mm")
print(f"Number of transducers: {num_locations}")
print(f"Center frequency: {f_centre/1e6} MHz")
print(f"Time samples: {time.num}")
print(f"Total simulation time: {T*1e6} μs")

# Try problem visualization- doesn't really work right now
# If problem.plot() raises, the error is printed and a manual 3D scatter of
# the transducer coordinates is attempted instead
try:
    problem.plot()
except Exception as e:
    print(f"\nError in problem visualization: {e}")
    print("Attempting alternative visualization...")
    try:
        fig = plt.figure(figsize=(10, 10))
        ax = fig.add_subplot(111, projection='3d')
        
        # Plot transducer positions
        # (columns 0, 1, 2 are used as the x, y, z coordinates)
        positions = problem.geometry.coordinates
        ax.scatter(positions[:, 0], positions[:, 1], positions[:, 2], 
                  c='red', marker='o', label='Transducers')
        
        # Add a semi-transparent volume to show head boundaries
        # (a 10x10x10 lattice of points spanning 0..space.limit on each axis)
        x, y, z = np.meshgrid(np.linspace(0, space.limit[0], 10),
                             np.linspace(0, space.limit[1], 10),
                             np.linspace(0, space.limit[2], 10))
        ax.scatter(x, y, z, alpha=0.1)
        
        ax.set_xlabel('X (m)')
        ax.set_ylabel('Y (m)')
        ax.set_zlabel('Z (m)')
        ax.legend()
        plt.show()
    except Exception as e:
        # The fallback is best-effort too: report the failure and carry on
        print(f"Alternative visualization failed: {e}")


In [None]:
import matplotlib.pyplot as plt
from ipywidgets import interact, IntSlider
import numpy as np

# Assuming 'problem' is your Stride Problem object
# Number of shots, used as the upper bound of the slider below
num_shots = len(problem.acquisitions.shots)
# Time axis in seconds: one sample every `time.step`, `time.num` samples
time_axis = time.step * np.arange(time.num)

def plot_wavelet(shot_index):
    """Plot the first wavelet row of the shot at position ``shot_index``.

    The signal is drawn against the notebook-level ``time_axis`` in a new
    figure; the title shows the shot number counted from 1.
    """
    selected_shot = problem.acquisitions.shots[shot_index]
    signal = selected_shot.wavelets.data[0, :]

    plt.figure(figsize=(12, 6))
    plt.plot(time_axis, signal)
    plt.title(f"Wavelet Signal for Shot {shot_index + 1}")
    plt.xlabel("Time (s)")
    plt.ylabel("Amplitude")
    plt.ylim(-1.1, 1.1)  # Adjust as needed
    plt.grid(True)
    plt.show()

# Interactive slider over the shot positions 0..num_shots-1; each change
# calls plot_wavelet with the selected index
interact(plot_wavelet, 
         shot_index=IntSlider(min=0, max=num_shots-1, step=1, value=0, 
                              description='Shot Index:',
                              style={'description_width': 'initial'}))


In [None]:
  import matplotlib.pyplot as plt
  from mpl_toolkits.mplot3d import Axes3D

  # Stand-alone 3D scatter of the geometry coordinates
  # (columns 0, 1, 2 are used as the x, y, z coordinates)
  fig = plt.figure()
  ax = fig.add_subplot(111, projection='3d')
  coords = problem.geometry.coordinates
  ax.scatter(coords[:, 0], coords[:, 1], coords[:, 2], c='r', marker='o')
  ax.set_xlabel('X axis')
  ax.set_ylabel('Y axis')
  ax.set_zlabel('Z axis')
  plt.show()


In [None]:
from stride import IsoAcousticDevito

# Create the PDE remotely, with `len` set to the number of workers
# NOTE(review): confirm that the `device="cuda"` keyword is actually honoured
# by IsoAcousticDevito; GPU selection may need to be given elsewhere
pde = IsoAcousticDevito.remote(grid=problem.grid, len=runtime.num_workers, device="cuda")


In [None]:
# Get all remaining shot IDs
# (these are the IDs that the asynchronous loop below iterates over)
shot_ids = problem.acquisitions.remaining_shot_ids

# Run an asynchronous loop across all shot IDs
@runtime.async_for(shot_ids)
async def loop(worker, shot_id):
    """Model one shot forward on ``worker`` and store its traces as observed data."""
    runtime.logger.info('Giving shot %d to %s' % (shot_id, worker.uid))

    # Sub-problem restricted to this shot, and its source wavelets
    sub_problem = problem.sub_problem(shot_id)
    source_wavelets = sub_problem.shot.wavelets

    # Launch the forward PDE on the given worker, then wait for its result
    forward_task = pde(source_wavelets, vp_true,
                       problem=sub_problem,
                       runtime=worker)
    traces = await forward_task.result()

    runtime.logger.info('Shot %d retrieved' % sub_problem.shot_id)

    # Copy the retrieved traces into the observed data of the full problem's shot
    target_shot = problem.acquisitions.get(shot_id)
    target_shot.observed.data[:] = traces.data

    runtime.logger.info('Retrieved traces for shot %d' % sub_problem.shot_id)

# Because this is an asynchronous loop, it needs to be awaited 
# (results are assigned to `_` so the cell does not display them)
_ = await loop

# Plot the result
_ = problem.acquisitions.plot()

In [None]:
import torch
# NOTE(review): this reports the current CUDA device of the notebook process;
# it does not by itself show which device the remote PDE ran on
print("Running PDE on:", torch.cuda.current_device())


In [None]:
from stride import forward
#traces = await pde(wavelets, vp_true, problem=problem, runtime=worker)
# Run default forward workflow
#await forward(problem, pde, vp_true, dump=False)

Starting model: 

Unlike in the forward problem, the speed of sound field is unknown to us at this point. We need to choose a starting assumption to begin optimizing from. Here we assume a homogeneous speed of 1500 m/s (approximately the speed of sound in water).

In [None]:
# Initialize starting model for inversion
# needs_grad=True because this is the field handed to the optimiser later on
vp = ScalarField.parameter(name='vp', grid=grid, needs_grad=True)
vp.fill(1500.) # Initialize with a constant velocity

# NOTE(review): a field named 'vp' (vp_true) was already added to the medium
# earlier — presumably this one takes its place; confirm
problem.medium.add(vp)

Imaging operators:

Define operators for our loss function, processing our source wavelets and the modelled/observed data traces, and an optimizer to update the speed of sound model after each iteration (here we use gradient descent).

We also limit the speed of sound to between 1400 and 1700 m/s; this range will need to change when we introduce bone.

In [None]:
from stride import L2DistanceLoss 

# Set up optimization components
# The loss is created remotely, with `len` set to the number of workers
loss = L2DistanceLoss.remote(len=runtime.num_workers)

from stride import ProcessWavelets, ProcessObserved, ProcessWaveletsObserved, ProcessTraces


# Pre-processing operators used inside the inversion loop, each created
# remotely with `len` set to the number of workers
process_wavelets = ProcessWavelets.remote(len=runtime.num_workers)
process_observed = ProcessObserved.remote(len=runtime.num_workers)
process_wavelets_observed = ProcessWaveletsObserved.remote(len=runtime.num_workers)
process_traces = ProcessTraces.remote(len=runtime.num_workers)

from stride import GradientDescent, ProcessGlobalGradient, ProcessModelIteration

# Configure optimization parameters
step_size = 10
# Unlike the operators above, these two are created locally (no .remote)
process_grad = ProcessGlobalGradient()
# Bounds of 1400 and 1700 are passed as min/max for the model iteration
process_model = ProcessModelIteration(min=1400., max=1700.)

# Gradient-descent optimiser acting on the starting model `vp`
optimiser = GradientDescent(vp, step_size=step_size,
                            process_grad=process_grad,
                            process_model=process_model)

Inverse problem -- estimating the speed of sound:

We use a multi-frequency approach, starting with lower frequencies and then moving to higher frequencies for more detail. This is done by dividing the optimization into blocks, each with a max frequency.

During each iteration, only a random subset of the shots will actually be used. 

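Here the optimization loop is written out explicitly; the utility function "adjoint" (shown commented out at the end of the cell) runs the same kind of loop in a single call. If you want to see more detailed code, take a look at the 2d example notebook.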

In [None]:
from stride import OptimisationLoop

# Clear the previous Devito operators
await pde.clear_operators()

optimisation_loop = OptimisationLoop()

# Specify a series of frequency bands, which we will introduce gradually 
# into the inversion in order to better condition it
# Run multi-frequency inversion

# Maximum frequency for each block, in Hz (0.3 MHz, then 0.4 MHz)
max_freqs = [0.3e6, 0.4e6]

# One block per entry of max_freqs; 4 iterations are run in each block
num_blocks = len(max_freqs)
num_iters = 4

# Start iterating over each block in the optimisation
# (each block is paired with its maximum frequency f_max)
for block, f_max in optimisation_loop.blocks(num_blocks, max_freqs):

    # Proceed through every iteration in the block
    for iteration in block.iterations(num_iters):
        runtime.logger.info('Starting iteration %d (out of %d), '
                            'block %d (out of %d)' %
                            (iteration.id+1, block.num_iterations, block.id+1,
                             optimisation_loop.num_blocks))

        # Select some shots for this iteration
        # (15 shot IDs, chosen randomly)
        shot_ids = problem.acquisitions.select_shot_ids(num=15, randomly=True)

        # Clear the gradient buffers of the variable
        vp.clear_grad()

        # Asynchronously loop over all the selected shot IDs
        # The function is redefined on every iteration so that it picks up the
        # current shot_ids, f_max and iteration
        @runtime.async_for(shot_ids)
        async def loop(worker, shot_id):
            runtime.logger.info('Giving shot %d to %s' % (shot_id, worker.uid))

            # Fetch one sub-problem corresponding to the shot ID
            sub_problem = problem.sub_problem(shot_id)
            wavelets = sub_problem.shot.wavelets
            observed = sub_problem.shot.observed

            # Pre-process the wavelets and observed
            # (none of these calls is awaited here; only the adjoint result
            # further down is)
            wavelets = process_wavelets(wavelets, f_max=f_max, filter_relaxation=0.75, runtime=worker)
            observed = process_observed(observed, f_max=f_max, filter_relaxation=0.75, runtime=worker)
            processed = process_wavelets_observed(wavelets, observed, f_max=f_max, runtime=worker)
            wavelets = processed.outputs[0]
            observed = processed.outputs[1]
            
            # Execute the PDE forward
            # (uses the model being inverted, `vp`, not the true model)
            modelled = pde(wavelets, vp, problem=sub_problem, runtime=worker)

            # Pre-process the modelled and observed traces
            traces = process_traces(modelled, observed, f_max=f_max, filter_relaxation=0.75, runtime=worker)
            # and use these pre-processed versions to calculate the
            # value of the loss_freq function
            fun = loss(traces.outputs[0], traces.outputs[1],
                       problem=sub_problem, runtime=worker)

            # run adjoint
            # (this is the only awaited call of the chain; its result is the
            # functional value that is logged and accumulated below)
            fun_value = await fun.remote.adjoint().result()

            iteration.add_loss(fun_value)
            runtime.logger.info('Functional value for shot %d: %s' % (shot_id, fun_value))

            runtime.logger.info('Retrieved gradient for shot %d' % sub_problem.shot_id)

        # Because this is an async loop, it needs to be awaited    
        _ = await loop
        # Update the vp with the calculated gradient by taking a step with the optimiser
        await optimiser.step()

        runtime.logger.info('Done iteration %d (out of %d), '
                            'block %d (out of %d) - Total loss_freq %e' %
                            (iteration.id+1, block.num_iterations, block.id+1,
                             optimisation_loop.num_blocks, iteration.total_loss))
        runtime.logger.info('====================================================================')

# Alternative way of running the inversion with stride's `adjoint` helper.
# Kept for reference only and not executed (it would need `adjoint` to be
# imported from stride first):
#
# for block, freq in optimisation_loop.blocks(num_blocks, max_freqs):
#     await adjoint(problem, pde, loss,
#         optimisation_loop, optimiser, vp,
#         num_iters=num_iters,
#         select_shots=dict(num=15, randomly=True),
#         f_max=freq)

# Plot the vp afterwards (once; the model was previously plotted twice)
vp.plot()

Tear down the mosaic runtime:

In [None]:
# Switch mosaic's interactive mode off again (it was switched on near the top)
await mosaic.interactive('off')