In [None]:
def Bloch_simul_rot(x,y,z,T1,T2,RF,Gz,time_step,slice_thick,spatial_point,off_range):
    """Build the per-time-step 3x3 Bloch propagation matrices of an RF pulse.

    For every off-resonance value and every RF sample, the returned matrix
    combines the rotation produced by the effective field (RF + off-resonance)
    during one time step with the multiplicative relaxation factors
    (rows 0 and 1 are scaled by E2, row 2 by E1). Only this multiplicative
    part is returned; no additive recovery term is computed here.

    Parameters (units as stated in the original notes)
    ----------
    x, y, z : unused by this implementation (kept for interface compatibility).
    T1, T2 : relaxation times in msec.
    RF : tensor of shape (2, N); RF[0] is the amplitude in Gauss and RF[1]
        the phase in degrees.
    Gz : unused here; the gradient-based offset was replaced by `off_range`.
    time_step : duration of one RF sample in msec.
    slice_thick : unused by this implementation.
    spatial_point : number of simulated points; must equal
        `off_range.shape[0]` so the RF and offset grids have the same shape.
    off_range : 1-D tensor of off-resonance frequencies (multiplied by 2*pi
        below, i.e. treated as Hz).

    Returns
    -------
    torch.Tensor of shape (spatial_point, N, 3, 3).
    """
    # Relaxation times: msec -> sec.
    T1 = T1 / 1000
    T2 = T2 / 1000

    length_RF = RF[0,:].shape[0]
    t_int = time_step * 1e-3  # msec -> sec

    # Gradient-based offset from the original MATLAB version, kept for reference:
    # delta_omega = 2*pi*42.57747892*10^6 * (-1:2/(spatial_point-1):1).'*slice_thick*0.001*Gz;
    # Tile the offsets across the actual number of RF samples (this used to be
    # hard-coded to 256, which broke every pulse of a different length).
    delta_omega = 2*np.pi*off_range.unsqueeze(1).repeat(1, length_RF)
    RF_amp = (RF[0,:] * 2 * np.pi * 4257.747892).repeat(spatial_point, 1)
    RF_phase = (RF[1,:]).repeat(spatial_point, 1)*np.pi/180

    # alpha: rotation angle about the effective field during one step.
    alpha = t_int * torch.sqrt(RF_amp ** 2 + delta_omega **2)
    # zeta: tilt of the effective field away from the z axis.
    zeta = torch.atan2(RF_amp, delta_omega)
    # theta: RF phase, i.e. azimuth of the RF field in the transverse plane.
    theta = RF_phase

    ca = torch.cos(alpha)
    sa = torch.sin(alpha)
    cz = torch.cos(zeta)
    sz = torch.sin(zeta)
    ct = torch.cos(theta)
    st = torch.sin(theta)
    E1 = np.exp(- t_int / T1)
    E2 = np.exp(- t_int / T2)

    # Closed-form entries of the step matrix; "<row>_<col>_part" is the
    # contribution of the input component <col> to the output component <row>.
    Mx_x_part = ct*(E2*ct*(sz**2) + cz*(E2*sa*st + E2*ca*ct*cz)) + st*(E2*ca*st - E2*ct*cz*sa)
    Mx_y_part = st*(E2*ct*(sz**2) + cz*(E2*sa*st + E2*ca*ct*cz)) - ct*(E2*ca*st - E2*ct*cz*sa)
    Mx_z_part = E2*ct*cz*sz - sz*(E2*sa*st + E2*ca*ct*cz)
    My_x_part = - ct*(- E2*st*(sz**2) + cz*(E2*ct*sa - E2*ca*cz*st)) - st*(E2*ca*ct + E2*cz*sa*st)
    My_y_part = ct*(E2*ca*ct + E2*cz*sa*st) - st*(- E2*st*(sz**2) + cz*(E2*ct*sa - E2*ca*cz*st))
    My_z_part = sz*(E2*ct*sa - E2*ca*cz*st) + E2*cz*st*sz
    Mz_x_part = ct*(E1*cz*sz - E1*ca*cz*sz) + E1*sa*st*sz
    Mz_y_part = st*(E1*cz*sz - E1*ca*cz*sz) - E1*ct*sa*sz
    Mz_z_part = E1*(cz**2) + E1*ca*(sz**2)

    rot = torch.zeros((spatial_point,length_RF,3,3))
    rot[:,:,0,0] = Mx_x_part
    rot[:,:,0,1] = Mx_y_part
    rot[:,:,0,2] = Mx_z_part
    rot[:,:,1,0] = My_x_part
    rot[:,:,1,1] = My_y_part
    rot[:,:,1,2] = My_z_part
    rot[:,:,2,0] = Mz_x_part
    rot[:,:,2,1] = Mz_y_part
    rot[:,:,2,2] = Mz_z_part
    return rot


In [None]:
def RF_simul(RF_pulse, RF_img, off_range, time_step):
    """Convert a complex RF waveform to magnitude/phase and simulate it.

    The real and imaginary parts are combined into magnitude and phase, the
    magnitude is divided by 2*pi*42.577e6*time_step*1e-4, the phase is
    converted to degrees, and the result is passed to `Bloch_simul_rot`.

    Parameters
    ----------
    RF_pulse : tensor, real part of the pulse; its first dimension is the
        number of samples (presumably shape (N, 1) — confirm against callers).
    RF_img : tensor with the same shape and dtype as `RF_pulse`, imaginary part.
    off_range : 1-D tensor of off-resonance frequencies to simulate.
    time_step : duration of one sample; multiplied by 1e3 before being handed
        to `Bloch_simul_rot`, i.e. treated as seconds here.

    Returns
    -------
    Tensor of step matrices returned by `Bloch_simul_rot`, moved to the
    module-level `device`.
    """
    size = RF_pulse.shape[0]
    RF_pulse_new = torch.zeros((size,2))
    RF_pulse_new[:,0] = torch.sqrt((RF_pulse**2+RF_img**2).squeeze())
    RF_pulse_new[:,1] = torch.angle(torch.complex(RF_pulse, RF_img).squeeze())

    # Conversion factor applied to the magnitude (gyromagnetic ratio in rad/s/T
    # times the step duration, times 1e-4).
    rad_per_gauss = 2*np.pi *42.577*1e+6*time_step*1e-4
    rr = 42.577 # MHz/T
    Gz =  40 # mT/m (fixed)
    pos = torch.abs(off_range[0] / rr / Gz) # mm

    # Divide directly by the conversion factor. The previous code normalised
    # by the peak magnitude and then multiplied by peak / factor, which is the
    # same value mathematically but evaluates to 0/0 = NaN for an all-zero pulse.
    mag = RF_pulse_new[:,0] / rad_per_gauss
    ph = (RF_pulse_new[:,1] / np.pi)*180

    # Row 0: magnitude, row 1: phase in degrees -> shape (2, size).
    pulse = torch.stack([mag, ph], 0)
    gg = torch.ones((1, RF_pulse_new.shape[0])) * Gz

    freq_shape = off_range.shape[0]
    rot = Bloch_simul_rot(np.zeros([freq_shape,1]),np.zeros([freq_shape,1]),np.ones([freq_shape,1]),1e+10,1e+10,pulse,gg,time_step * 1e+3,pos * 0.001,freq_shape,off_range).to(device)

    return rot

In [3]:
from os import path
import sys

import os
import time
import argparse
import gym
import numpy as np
from scipy.io import loadmat
import torch
import torch.nn as nn
from torch.optim.lr_scheduler import ReduceLROnPlateau
import matplotlib.pyplot as plt
# Select GPU
os.environ.update({
    "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
    "CUDA_VISIBLE_DEVICES": "1",
})

if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Wall-clock reference used by the final "Total time" report.
start_time = time.time()

# Hyperparameter setting
gamma = 2 * np.pi * 42.5775 * 1e+6          # gyromagnetic ratio [rad/s/T]
N_time_step = 256                           # number of RF samples
max_RF_amp = 0.2 * 1e-4                     # T
max_N_iter = 5000                           # iteration cap of the optimisation
mu = 0.00015                                # learning rate
alpha = 0.2795                              # trade off between slice profile loss and sar
freq = torch.linspace(-8000, 8000, 2000)    # Hz
T = 5.12*1e-3                               # sec
k = freq.shape[0]                           # number of simulated frequencies
max_rad = max_RF_amp * gamma * T/N_time_step
time_step = T / N_time_step

# Block-diagonal matrices with one 3x3 block per frequency:
#   B1 block = [[0, 0, 0], [0, 0, 1], [0, -1, 0]]
#   B2 block = [[0, 0, -1], [0, 0, 0], [1, 0, 0]]
#   P  block = diag(0, 0, 1)  (keeps only the z component)
B1 = torch.zeros((3*k, 3*k)).to(device)
B2 = torch.zeros((3*k, 3*k)).to(device)
P = torch.zeros(3*k, 3*k).to(device)

# Row/column positions of the x, y and z entries of every block.
x_idx = 3 * torch.arange(k, device=device)
y_idx = x_idx + 1
z_idx = x_idx + 2

B1[y_idx, z_idx] = 1
B1[z_idx, y_idx] = -1
B2[x_idx, z_idx] = -1
B2[z_idx, x_idx] = 1
P[z_idx, z_idx] = 1

In [4]:
from scipy.io import loadmat, savemat

# Load the reference SLR pulse and simulate it on the same frequency grid and
# time step as the optimisation, using the shared configuration values
# (freq, T, N_time_step) instead of repeating their literal numbers.
a = loadmat('./SLR_inv_oc.mat')
slr = a['slr']
import torch
import numpy as np
slr_real = torch.from_numpy(slr)
# The reference pulse is purely real: the imaginary part is a zero tensor with
# the same shape and dtype as the real part (torch.complex needs both to match).
rot = RF_simul(slr_real, torch.zeros_like(slr_real), freq, T / N_time_step)


In [5]:
# Propagate the equilibrium magnetisation (0, 0, 1) through the step matrices
# in `rot`; the projected final state K is the target of the optimisation.
Mt = torch.zeros([k, N_time_step+1, 3]).to(device)
Mt[:, 0, 2] = 1
torch.set_printoptions(threshold=5000)

for step in range(1, N_time_step+1):
    prev_M = Mt[:, step-1]      # magnetisation before this step, (k, 3)
    step_rot = rot[:, step-1]   # step matrices, (k, 3, 3)
    for row in range(3):
        Mt[:, step, row] = (prev_M[:, 0] * step_rot[:, row, 0]) + (prev_M[:, 1] * step_rot[:, row, 1]) + (prev_M[:, 2] * step_rot[:, row, 2])

# Final magnetisation stacked as one (3k, 1) column: x, y, z per frequency.
D = Mt[:, -1, :].squeeze().reshape(-1, 1)
# P keeps only the z components.
K = torch.mm(P, D)


In [6]:
# Create input vectors: read the initial pulse components from the preset
# file as float32 tensors placed on the selected device.
preset = loadmat('./OC_input_inv_STA.mat')
w1, w2 = (
    torch.FloatTensor(np.array(preset[name], dtype=np.float32)).to(device)
    for name in ('w1', 'w2')
)

In [7]:
# Per-iteration histories, stored as plain Python floats so they can be saved
# or plotted without holding on to (possibly CUDA) tensors.
sar_loss = []
L2_loss  = []
total_loss = []

# Loop-invariant quantities, computed once.
time_step = T / N_time_step
to_gauss = gamma * time_step * 1e-4
sar_slr = np.sum((slr/to_gauss)**2) * time_step * 1e+6
P_t = torch.transpose(P, 0, 1)
PtP = torch.mm(P_t, P)      # dense (3k x 3k) product; previously rebuilt every iteration

for e in range(max_N_iter):
    # Forward pass: propagate (0, 0, 1) through the step matrices of the
    # current pulse (w1 = real part, w2 = imaginary part).
    Mt = torch.zeros([k, N_time_step+1, 3]).to(device)
    Mt[:, 0, 2] = 1
    rot = (RF_simul(torch.transpose(w1,0,1), torch.transpose(w2,0,1), freq, T/N_time_step)).squeeze()

    for f in range(1 ,N_time_step+1):
        prev_M = Mt[:, f-1]
        step_rot = rot[:, f-1]
        for row in range(3):
            Mt[:, f, row] = prev_M[:, 0] * step_rot[:, row, 0] + prev_M[:, 1] * step_rot[:, row, 1] + prev_M[:, 2] * step_rot[:, row, 2]

    # Drop the initial state: Mt[:, u] is now the state after RF sample u.
    Mt = Mt[:, 1:, :]

    # Profile error of the current pulse against the target K.
    M_T = (Mt[:, -1, :]).squeeze().reshape(-1, 1)
    dist = torch.mm(P, M_T) - K
    phi_d = (1/2 * (torch.mm(torch.transpose(dist, 0, 1), dist))).item()

    # Backward pass: terminal value P^T P M_T - P^T K, then a backward
    # recursion with the transposed step matrices.
    lambda_T = torch.mm(PtP, M_T) - torch.mm(P_t, K)
    lambda_ = torch.zeros((k, N_time_step, 3)).to(device)
    lambda_[:, -1, :] = lambda_T.reshape(k, 3)

    for b in range(N_time_step-1, 0, -1):
        next_lam = lambda_[:, b]
        step_rot = rot[:, b]
        for col in range(3):
            lambda_[:, b-1, col] = next_lam[:, 0] * step_rot[:, 0, col] + next_lam[:, 1] * step_rot[:, 1, col] + next_lam[:, 2] * step_rot[:, 2, col]

    if (e) % 30 == 0:
        sar_rf = torch.sum((w1**2 + w2 **2)) / (to_gauss**2) * time_step * 1e+6
        print("★★★★★★★★★★★★★★★★★★★★★★★★")
        print(f"Iteration : {e}")
        print(f"SAR of SLR : {sar_slr:.5f}")
        print(f"SAR of RF : {sar_rf:.5f}")
        print(f"SAR reduction : {(sar_slr - sar_rf)/sar_slr *100:.4f} %")
        energy = torch.sum(w1**2 + w2**2)
        print(f"iteration = {e}, phi_d = {phi_d}, energy = {energy}")

    # Update rf pulse: gradient step on w1 with step size mu; the gradient is
    # lambda^T B1 M plus the energy penalty alpha * w1.
    for u in range(N_time_step):
        M = Mt[:, u, :].squeeze().reshape(-1,1)
        lamb = lambda_[:, u, :].squeeze().reshape(-1, 1)
        w1[:, u] = w1[:, u] - mu * (torch.mm(torch.mm(torch.transpose(lamb, 0, 1), B1), M) + alpha * w1[:, u])
        # w2 is intentionally left unchanged; its update was disabled:
        #w2[:, u] = w2[:, u] - mu * (torch.mm(torch.mm(torch.transpose(lamb, 0, 1), B2), M) + alpha * w2[:, u])

    # Bookkeeping. NOTE(review): phi_d belongs to the pulse before this
    # update while energy is measured after it — confirm this is intended.
    energy = torch.sum(w1**2 + w2**2)
    sar_loss.append(energy.item())
    L2_loss.append(phi_d)
    # Evaluate in tensor precision as before, then store a plain float.
    loss = (phi_d + alpha * energy).item()
    # After 500 iterations, stop as soon as the total loss increases.
    if (e > 500) :
        if (loss > total_loss[-1]) :
            break
    total_loss.append(loss)


print(f"Total time: {time.time() - start_time:.1f}")

★★★★★★★★★★★★★★★★★★★★★★★★
Iteration : 0
SAR of SLR : 19.59590
SAR of RF : 19.05526
SAR reduction : 2.7589 %
iteration = 0, phi_d = 18.92304229736328, energy = 0.27275004982948303
★★★★★★★★★★★★★★★★★★★★★★★★
Iteration : 30
SAR of SLR : 19.59590
SAR of RF : 20.69562
SAR reduction : -5.6120 %
iteration = 30, phi_d = 0.014026605524122715, energy = 0.29622966051101685
★★★★★★★★★★★★★★★★★★★★★★★★
Iteration : 60
SAR of SLR : 19.59590
SAR of RF : 20.60730
SAR reduction : -5.1613 %
iteration = 60, phi_d = 0.006514052394777536, energy = 0.2949654459953308
★★★★★★★★★★★★★★★★★★★★★★★★
Iteration : 90
SAR of SLR : 19.59590
SAR of RF : 20.54689
SAR reduction : -4.8530 %
iteration = 90, phi_d = 0.004710513167083263, energy = 0.2941007614135742
★★★★★★★★★★★★★★★★★★★★★★★★
Iteration : 120
SAR of SLR : 19.59590
SAR of RF : 20.49272
SAR reduction : -4.5766 %
iteration = 120, phi_d = 0.003909674938768148, energy = 0.29332542419433594
★★★★★★★★★★★★★★★★★★★★★★★★
Iteration : 150
SAR of SLR : 19.59590
SAR of RF : 20.43900
SA

In [8]:
# Save the optimisation result. Tensors are moved to host memory and turned
# into NumPy arrays / Python floats first: savemat converts its values with
# NumPy, which cannot read CUDA tensors.
# NOTE(review): the file name encodes "mu00002" while mu is set to 0.00015 in
# the configuration cell — confirm the name is still accurate.
# NOTE(review): the key "total loss" contains a space, which is not a valid
# MATLAB identifier — confirm downstream readers can access it.
savemat(
    "OC_result_inv8_best_mu00002_STA.mat",
    {
        "L2loss": [float(v) for v in L2_loss],
        "sar_loss": [float(v) for v in sar_loss],
        "w1": w1.detach().cpu().numpy(),
        "w2": w2.detach().cpu().numpy(),
        "total loss": [float(v) for v in total_loss],
        "iteration": e,
    },
)