In [None]:
import torch
import torch.nn as nn
import numpy as np
from tqdm.notebook import tqdm
from einops import rearrange, repeat, einsum

from src.data.simulation import Simulation

In [7]:
# Run on the GPU whenever PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [8]:
# Input files: one simulated subject plus the antenna description.
SIMULATION_FILE = "data/simulations/children_0_tubes_0_id_11815.h5"
ANTENNA_FILE = "data/antenna/antenna.h5"

sim = Simulation(SIMULATION_FILE, ANTENNA_FILE)

In [9]:
# Sanity check: total of the `subject` array. For a boolean / 0-1 mask this is the
# number of voxels inside the subject region (the same array is used as `mask` below).
sim.simulation_raw_data.subject.sum()

65532

In [10]:
# Materialise the raw simulation arrays as tensors on `device`.
# torch.tensor() always copies, so these tensors never alias the NumPy data.
# Downstream, field[0] is treated as the E field and field[1] as the B field;
# properties[0] / properties[2] are used as conductivity / density.
field = torch.tensor(sim.simulation_raw_data.field, dtype=torch.float32).to(device=device)
prop = torch.tensor(sim.simulation_raw_data.properties, dtype=torch.float32).to(device=device)
mask = torch.tensor(sim.simulation_raw_data.subject, dtype=torch.bool).to(device=device)

Torch: optimize per-coil phase and amplitude for B1+ homogeneity (optionally with a peak-SAR penalty)

In [6]:
class PhaseShift(nn.Module):
    """Learnable per-coil complex weighting (phase + amplitude) of a multi-coil field.

    Each coil c gets a complex weight w_c = amplitude_c * exp(1j * phase_c). The
    forward pass multiplies every coil's field by its weight and sums over coils.
    The complex product is written out explicitly on separate real/imaginary
    channels, because the field stores real and imaginary parts on their own axis.

    Args:
        num_coils: number of coils, i.e. the length of the trailing ``coils`` axis
            of the input field. Defaults to 8, the value this notebook was written for.
    """

    def __init__(self, num_coils: int = 8):
        super().__init__()
        # Random initialisation (seed the global RNG before construction for
        # reproducibility). A negative amplitude is equivalent to a phase offset of pi.
        self.phase = nn.Parameter(torch.randn(num_coils, dtype=torch.float32))
        self.amplitude = nn.Parameter(torch.randn(num_coils, dtype=torch.float32))

    def forward(self, field):
        """Apply the per-coil complex weights and sum over coils.

        Args:
            field: tensor laid out as (hf, reim, fieldxyz, ..., coils). ``reim`` must
                have size 2 (real, imaginary), ``coils`` must equal ``num_coils``, and
                ``...`` stands for any number of spatial axes.

        Returns:
            Tensor of shape (hf, 2, fieldxyz, ...) holding the real/imaginary parts
            of sum_c w_c * field[..., c].
        """
        w_re = self.amplitude * torch.cos(self.phase)
        w_im = self.amplitude * torch.sin(self.phase)
        # One real 2x2 matrix per coil implementing complex multiplication:
        #   [out_re]   [w_re  -w_im] [f_re]
        #   [out_im] = [w_im   w_re] [f_im]
        coeffs = torch.stack(
            (
                torch.stack((w_re, -w_im), dim=0),
                torch.stack((w_im, w_re), dim=0),
            ),
            dim=0,
        )  # (reimout, reim, coils)
        # The same weights apply to every `hf` entry, so there is no need to
        # materialise copies along hf; einsum broadcasts them.
        return torch.einsum('hrx...c,orc->hox...', field, coeffs)

In [7]:
def b1_calc(field):
    """Return the complex B1+ map 0.5 * (Bx + i*By) built from ``field[1]``.

    ``field[1]`` (treated as the B field) holds real and imaginary parts on its
    first axis; the next axis is the vector component, where index 0 is used as
    x and index 1 as y (assumes (x, y, z) ordering — TODO confirm).
    """
    b_re, b_im = field[1][0], field[1][1]
    bx = b_re[0] + 1j * b_im[0]
    by = b_re[1] + 1j * b_im[1]
    return 0.5 * (bx + 1j * by)

In [8]:
def b1_homogeneity_cost(field, mask):
    """Homogeneity score of |B1+| over the voxels selected by ``mask``.

    Computed as mean(|B1+|) / (std(|B1+|) + 1e-6), i.e. roughly the reciprocal of
    the coefficient of variation. Larger means more homogeneous, which is why the
    losses negate it. ``std`` uses PyTorch's default (unbiased) estimator.

    NOTE(review): if ``mask`` selects a single voxel the unbiased std is NaN.
    """
    magnitudes = b1_calc(field)[mask].abs()
    return magnitudes.mean() / (magnitudes.std() + 1e-6)

In [9]:
def sars_calc(field, properties, mask):
    """Per-voxel SAR-like value sigma * |E|^2 / rho over the voxels in ``mask``.

    ``field[0]`` is used as the E field: squaring it and summing over its first two
    axes (real/imaginary and the vector components) gives |E|^2. ``properties[0]``
    is used as conductivity and ``properties[2]`` as density.

    NOTE(review): no factor of 1/2 is applied; whether one is needed depends on
    whether the stored field amplitudes are peak or RMS values — confirm.
    """
    e_magnitude_sq = field[0].pow(2).sum(dim=(0, 1))[mask]
    sigma = properties[0][mask]
    rho = properties[2][mask]
    return sigma * e_magnitude_sq / rho

In [10]:
def sar_cost(field, properties, mask):
    """Peak (maximum) voxel value of ``sars_calc`` inside ``mask``."""
    sar_per_voxel = sars_calc(field, properties, mask)
    return sar_per_voxel.max()

In [11]:
def loss_1(field, mask):
    """Loss that rewards B1+ homogeneity: the negated ``b1_homogeneity_cost``."""
    return -b1_homogeneity_cost(field, mask)

In [12]:
def loss_2(field, prop, mask, lambda_param=0.1):
    """Negated B1+ homogeneity plus a peak-SAR penalty weighted by ``lambda_param``.

    NOTE(review): the homogeneity term (mean/std of |B1+|) is essentially unchanged
    (up to the 1e-6 epsilon) when all coil amplitudes are scaled by a common factor.
    The SAR term, however, grows with amplitude squared. With unconstrained
    amplitudes this loss can therefore be lowered just by shrinking every amplitude.
    Consider normalising the SAR term (e.g. by mean |B1+|^2) or constraining the
    amplitudes.
    """
    homogeneity_term = -b1_homogeneity_cost(field, mask)
    sar_penalty = lambda_param * sar_cost(field, prop, mask)
    return homogeneity_term + sar_penalty

In [13]:
# Seed PyTorch's global RNG so the random initial phases/amplitudes — and therefore
# the whole optimisation run below — are reproducible on a fresh kernel.
SEED = 0
torch.manual_seed(SEED)

model = PhaseShift().to(device)
print(model.amplitude)
print(model.phase)

Parameter containing:
tensor([ 0.8403,  0.5319,  1.1106, -1.0601,  2.3744,  1.4144,  1.2957,  0.1888],
       device='cuda:0', requires_grad=True)
Parameter containing:
tensor([ 0.1447,  0.6035,  1.9024,  1.3242,  0.2704, -0.1394,  0.1005,  0.6391],
       device='cuda:0', requires_grad=True)


In [13]:
# Adam over all learnable parameters of the PhaseShift model (phase and amplitude).
# The variable name `optmizer` (sic) is kept because the training cell refers to it.
LEARNING_RATE = 0.01
optmizer = torch.optim.Adam(params=model.parameters(), lr=LEARNING_RATE)

In [None]:
# Optimise the coil phases/amplitudes against loss_2 (B1+ homogeneity + SAR penalty).
# NOTE: re-running this cell continues from the current model/optimizer state.
iterations = 1000
log_every = 100
for i in tqdm(range(iterations)):
    optmizer.zero_grad()
    shifted_field = model(field)
    loss = loss_2(shifted_field, prop, mask)
    loss.backward()
    optmizer.step()

    # The logged loss is the one computed before this iteration's step. Always log
    # the last iteration too, so the final (converged) loss is visible.
    if i % log_every == 0 or i == iterations - 1:
        print(f"Iteration {i}: Loss = {loss.item()}")