In [1]:
import torch
import torch.nn as nn
import numpy as np
from tqdm.notebook import tqdm
from einops import rearrange, repeat, einsum
from skimage.measure import label, regionprops

from src.data.simulation import Simulation

In [2]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [3]:
sim = Simulation(
    "data/simulations/children_0_tubes_0_id_11815.h5",
    "data/antenna/antenna.h5"
)

In [4]:
labeled_mask = label(sim.simulation_raw_data.subject)
regions = regionprops(labeled_mask)
coords=regions[0].bbox
field_np = sim.simulation_raw_data.field[:, :, :, int(coords[0]):int(coords[3]), int(coords[1]):int(coords[4]), int(coords[2]):int(coords[5]), :]
prop_np = sim.simulation_raw_data.properties[:, int(coords[0]):int(coords[3]), int(coords[1]):int(coords[4]), int(coords[2]):int(coords[5])]

In [5]:
field = torch.tensor(field_np, dtype=torch.float32).to(device)
prop = torch.tensor(prop_np, dtype=torch.float32).to(device)

Torch

In [6]:
class PhaseShift(nn.Module):
    def __init__(self):
        super().__init__()
        self.phase = nn.Parameter(torch.randn(8, dtype=torch.float32))
        self.amplitude = nn.Parameter(torch.randn(8, dtype=torch.float32))

    def forward(self, field):
        re_phase = torch.cos(self.phase) * self.amplitude
        im_phase = torch.sin(self.phase) * self.amplitude
        coeffs_real = torch.stack((re_phase, -im_phase), dim=0)
        coeffs_im = torch.stack((im_phase, re_phase), dim=0)
        coeffs = torch.stack((coeffs_real, coeffs_im), dim=0)
        coeffs = repeat(coeffs, 'reimout reim coils -> hf reimout reim coils', hf=2)
        field_shift = einsum(field, coeffs, 'hf reim fieldxyz ... coils, hf reimout reim coils -> hf reimout fieldxyz ...')
        return field_shift

In [7]:
def b1_calc(field):
    b_field = field[1]
    b_field_complex = b_field[0] + 1j * b_field[1]
    return 0.5*(b_field_complex[0] + 1j*b_field_complex[1])

In [8]:
def b1_homogeneity_cost(field):
    b1_plus = b1_calc(field)
    b1_plus_magnitude = torch.abs(b1_plus)
    b1_plus_mean = torch.mean(b1_plus_magnitude)
    b1_plus_std = torch.std(b1_plus_magnitude)
    return b1_plus_mean / (b1_plus_std + 1e-6)

In [9]:
def sars_calc(field, properties):
    e_field = field[0]
    abs_efield_sq = torch.sum(e_field**2, axis=(0,1))

    # get the conductivity and density tensors
    conductivity = properties[0]
    density = properties[2]

    return conductivity * abs_efield_sq / density

In [10]:
def sar_cost(field, properties):
    return torch.max(sars_calc(field, properties))

In [11]:
def loss_1(field):
    return -1 * b1_homogeneity_cost(field)

In [12]:
def loss_2(field, prop, lambda_param=0.1):
    return -1 * b1_homogeneity_cost(field) + lambda_param * sar_cost(field, prop)

In [13]:
model = PhaseShift().to(device)
print(model.amplitude)
print(model.phase)

Parameter containing:
tensor([ 0.8866, -0.2040, -0.2662, -0.3972, -0.3892, -0.7052,  0.0299, -0.1903],
       device='cuda:0', requires_grad=True)
Parameter containing:
tensor([-1.8614, -1.3950,  0.5299, -0.4911, -0.0032, -0.3384,  1.3015,  0.7771],
       device='cuda:0', requires_grad=True)


In [14]:
optmizer = torch.optim.Adam(model.parameters(), lr=0.01)

In [15]:
iterations = 1000
for i in tqdm(range(iterations)):
    optmizer.zero_grad()
    shifted_field = model(field)
    loss = loss_2(shifted_field, prop)
    loss.backward()
    optmizer.step()

    if i % 100 == 0:
        print(f"Iteration {i}: Loss = {loss.item()}")

  0%|          | 0/1000 [00:00<?, ?it/s]

Iteration 0: Loss = -0.5671815872192383
Iteration 100: Loss = -2.750995397567749
Iteration 200: Loss = -2.8596324920654297
Iteration 300: Loss = -2.908470392227173
Iteration 400: Loss = -2.928636312484741
Iteration 500: Loss = -2.935072422027588
Iteration 600: Loss = -2.9365181922912598
Iteration 700: Loss = -2.9352550506591797
Iteration 800: Loss = -2.937091112136841
Iteration 900: Loss = -2.9370503425598145
