In [1]:
def Bloch_simul_rot(x,y,z,T1,T2,RF,Gz,time_step,slice_thick,spatial_point,off_range, device):
    """Build per-sample Bloch rotation/relaxation matrices in the rotating frame.

    For every batch item b, off-resonance point s and RF sample n, rot[b, s, n] is the
    3x3 matrix mapping the magnetisation before sample n to the magnetisation after it:
    a rotation by alpha = t * sqrt(RF_amp^2 + delta_omega^2) about the effective field
    (tilt zeta from z, phase theta), with rows 0-1 scaled by E2 and row 2 by E1.

    Args:
        x, y, z: (spatial_point, 1) positions. Not used by the computation.
        T1, T2: relaxation times [msec].
        RF: (B, 2, N). RF[:, 0, :] is the amplitude [Gauss], RF[:, 1, :] the phase [degree].
        Gz: gradient [mT/m] (1 x N). Not used by the computation.
        time_step: duration of one RF sample [msec].
        slice_thick: half of the simulated thickness [m]. Not used by the computation.
        spatial_point: number of simulated off-resonance points.
        off_range: (spatial_point,) off-resonance frequencies [Hz].
        device: torch device the result is built on.

    Returns:
        rot: (B, spatial_point, N, 3, 3) tensor; rot[b, s, n] @ M gives the
        magnetisation after sample n.

    NOTE(review): relaxation only enters as the multiplicative factors E1/E2; there is
    no (1 - E1) recovery term, so the model is approximate unless T1/T2 are very long.
    """
    T1 = T1 / 1000  # msec -> sec
    T2 = T2 / 1000  # msec -> sec
    batch = RF.shape[0]
    length_RF = RF.shape[-1]
    t_int = time_step * 1e-3  # msec -> sec

    # Off-resonance [Hz] -> [rad/s], shaped (1, spatial_point, N) so it lines up with the
    # RF terms; the length is taken from RF so pulses of any length are supported.
    delta_omega = 2*np.pi*off_range.unsqueeze(1).unsqueeze(0).repeat(1, 1, length_RF).to(device)
    # Amplitude Gauss -> rad/s (gamma = 4257.747892 Hz/G) and phase degree -> rad,
    # both broadcast to (B, spatial_point, N).
    RF_amp = (RF[:, 0,:].unsqueeze(1) * 2 * np.pi * 4257.747892).repeat(1, spatial_point, 1).to(device)
    RF_phase = ((RF[:, 1,:]).unsqueeze(1).repeat(1, spatial_point, 1)*np.pi/180).to(device)

    alpha = t_int * torch.sqrt(RF_amp ** 2 + delta_omega **2)  # rotation angle per sample
    zeta = torch.atan2(RF_amp, delta_omega)  # tilt of the effective field away from z
    theta = RF_phase  # B x spatial_points x N (N : RF length)

    ca = torch.cos(alpha)
    sa = torch.sin(alpha)
    cz = torch.cos(zeta)
    sz = torch.sin(zeta)
    ct = torch.cos(theta)
    st = torch.sin(theta)
    E1 = np.exp(- t_int / T1)
    E2 = np.exp(- t_int / T2)

    rot = torch.zeros((batch, spatial_point,length_RF,3,3), device=device)
    rot[:,:,:,0,0] = ct*(E2*ct*(sz**2) + cz*(E2*(sa*st + ca*ct*cz))) + st*(E2*(ca*st - ct*cz*sa))
    rot[:,:,:,0,1] = st*(E2*ct*(sz**2) + cz*(E2*(sa*st + ca*ct*cz))) - ct*(E2*(ca*st - ct*cz*sa))
    rot[:,:,:,0,2] = E2*ct*cz*sz - sz*(E2*(sa*st + ca*ct*cz))
    rot[:,:,:,1,0] = - ct*(- E2*st*(sz**2) + cz*(E2*(ct*sa - ca*cz*st))) - st*(E2*(ca*ct + cz*sa*st))
    rot[:,:,:,1,1] = ct*(E2*(ca*ct + cz*sa*st)) - st*(- E2*st*(sz**2) + cz*(E2*(ct*sa - ca*cz*st)))
    rot[:,:,:,1,2] = sz*(E2*(ct*sa - ca*cz*st)) + E2*cz*st*sz
    rot[:,:,:,2,0] = ct*(E1*(1-ca)*cz*sz) + E1*sa*st*sz
    rot[:,:,:,2,1] = st*(E1*(1-ca)*cz*sz) - E1*ct*sa*sz
    rot[:,:,:,2,2] = E1*(cz**2 + ca*(sz**2))
    return rot


In [2]:
def RF_simul(RF_pulse, RF_img, off_range, time_step):
    """Simulate per-sample Bloch rotation matrices for a complex RF waveform.

    The waveform (real part RF_pulse, imaginary part RF_img) is converted to amplitude
    [Gauss] and phase [degree] and passed to ``Bloch_simul_rot`` with T1 = T2 = 1e10 ms.
    The amplitude conversion divides by gamma * time_step, so the inputs are presumably
    per-sample rotation angles [rad] — TODO confirm.

    Args:
        RF_pulse, RF_img: (B, N, 1) real / imaginary parts of the waveform.
        off_range: (spatial_points,) off-resonance frequencies [Hz].
        time_step: duration of one RF sample [sec].

    Returns:
        (B, spatial_points, N, 3, 3) rotation matrices, moved to the notebook-global
        ``device`` (NOTE: relies on the global ``device`` defined in the config cell).
    """
    batch = RF_pulse.shape[0]
    size = RF_pulse.shape[1]
    # Work buffer on the input's device avoids a GPU -> CPU -> GPU round trip.
    RF_pulse_new = torch.zeros(batch, size, 2, device=RF_pulse.device)

    # Polar form of the complex waveform; reshape (not squeeze) keeps the batch axis
    # and the sample axis even when either has length 1.
    RF_pulse_new[:, :, 0] = torch.sqrt(RF_pulse**2 + RF_img**2).reshape(batch, size)
    RF_pulse_new[:, :, 1] = torch.atan2(RF_img, RF_pulse).reshape(batch, size)

    # Angle [rad] -> amplitude [Gauss]: B[T] = angle / (gamma * dt), and 1 T = 1e4 G.
    # gamma must equal the 4257.747892 Hz/G that Bloch_simul_rot uses to convert back,
    # otherwise the rad -> Gauss -> rad round trip rescales every flip angle.
    # Dividing directly (instead of normalising by the per-pulse maximum first) also
    # keeps an all-zero pulse finite instead of producing 0/0 = NaN.
    rr = 42.57747892  # MHz/T
    to_gauss = 2 * np.pi * rr * 1e+6 * time_step * 1e-4
    mag = RF_pulse_new[:, :, 0] / to_gauss
    ph = (RF_pulse_new[:, :, 1] / np.pi) * 180  # rad -> degree

    Gz = 40  # mT/m (fixed)
    pos = torch.abs(off_range[0] / rr / Gz)  # mm
    pulse = torch.stack([mag, ph], dim=1)  # (B, 2, N): amplitude row, phase row
    gg = torch.ones((1, batch)) * Gz

    freq_shape = off_range.shape[0]
    rot = Bloch_simul_rot(np.zeros([freq_shape,1]),np.zeros([freq_shape,1]),np.ones([freq_shape,1]), 1e+10, 1e+10, pulse, gg, time_step * 1e+3,pos * 0.001,freq_shape,off_range, device)

    return rot.to(device)

In [3]:
import argparse
import os
import sys
import time
from os import path

import gym
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
from scipy.io import loadmat, savemat
from torch.optim.lr_scheduler import ReduceLROnPlateau
# --- Device selection ---------------------------------------------------------
# Pin the run to physical GPU 3 (PCI bus order). These variables presumably need to be
# set before CUDA is first initialised in this kernel.
os.environ.update({"CUDA_DEVICE_ORDER": "PCI_BUS_ID", "CUDA_VISIBLE_DEVICES": "3"})
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# --- Hyperparameters ----------------------------------------------------------
gamma = 2 * np.pi * 42.5775 * 1e+6              # gyromagnetic ratio [rad/s/T]
N_time_step = 256                                # number of RF samples
max_RF_amp = 0.2 * 1e-4                          # [T]
max_N_iter = 500                                 # optimisation iterations
mu = 0.0002                                      # learning rate
alpha = 0.2795                                   # trade-off between slice-profile loss and SAR
freq = torch.linspace(-8000, 8000, 3000, device=device)  # off-resonance grid [Hz]
T = 5.12 * 1e-3                                  # pulse duration [sec]
k = len(freq)                                    # number of frequency points
time_step = T / N_time_step                      # duration of one RF sample [sec]
max_rad = max_RF_amp * gamma * T / N_time_step   # per-sample angle at max_RF_amp [rad]

# Block-diagonal (3k x 3k) operators: each 3x3 diagonal block acts on the (Mx, My, Mz)
# of one frequency point.
#   B1 / B2: skew-symmetric blocks used in the w1 / w2 gradient (lambda^T B M).
#   P: projection keeping only the Mz entries.
# Built with vectorised index assignment instead of a k-step Python loop of small
# host -> device copies; the resulting matrices are identical.
block_start = torch.arange(k, device=device) * 3  # first row/column of each 3x3 block

B1 = torch.zeros((3*k, 3*k), device = device)
B1[block_start + 1, block_start + 2] = 1    # block [[0,0,0],[0,0,1],[0,-1,0]]
B1[block_start + 2, block_start + 1] = -1

B2 = torch.zeros((3*k, 3*k), device = device)
B2[block_start, block_start + 2] = -1       # block [[0,0,-1],[0,0,0],[1,0,0]]
B2[block_start + 2, block_start] = 1

# create Projection matrix P
P = torch.zeros((3*k, 3*k), device = device)
P[block_start + 2, block_start + 2] = 1

In [4]:
# Reference pulse: simulate the SLR inversion pulse on the SAME frequency grid (`freq`)
# and sample time (T / N_time_step) as the optimisation loop, so its final Mz profile
# is a valid target for it.
slr_mat = loadmat('./SLR_inv_oc.mat')
slr = slr_mat['slr']

slr_real = torch.from_numpy(slr).unsqueeze(0)  # (1, N, 1) if slr is stored as (N, 1)
slr_imag = torch.zeros_like(slr_real)          # purely real pulse; same dtype/device as slr_real
rot = RF_simul(slr_real, slr_imag, freq, T / N_time_step)

rot = rot[0]  # drop the batch axis -> (k, N, 3, 3)
batch = 2     # repeat factor applied to the initial waveforms in the next cell

In [5]:
# Propagate M = (0, 0, 1) through every sample of the reference pulse:
# Mt[:, f] = rot[:, f-1] @ Mt[:, f-1], written out row by row.
Mt = torch.zeros([k, N_time_step+1, 3], device = device)
Mt[:, 0, 2] = 1
torch.set_printoptions(threshold=5000)

for f in range(1, N_time_step+1):
    m_prev = Mt[:, f-1]      # (k, 3)
    r_step = rot[:, f-1]     # (k, 3, 3)
    for row in range(3):
        Mt[:, f, row] = (m_prev[:, 0] * r_step[:, row, 0]) + (m_prev[:, 1] * r_step[:, row, 1]) + (m_prev[:, 2] * r_step[:, row, 2])

# Target vector: final magnetisation flattened to (3k, 1); P keeps only its Mz entries.
D = Mt[:, -1, :].reshape(-1, 1)
K = torch.mm(P, D)


In [6]:
# Initial waveforms: w1 is fed to RF_simul as the real part, w2 as the imaginary part.
# Each stored array gets a singleton axis 1 and is repeated `batch` times along axis 0
# (presumably the stored arrays are 2-D (rows, N) — TODO confirm).
preset = loadmat('./OC_input_inv_origin_uniform_v8.mat')


def _load_waveform(key):
    """Load preset[key] as float32, add axis 1, repeat `batch` times and move to `device`."""
    as_float = np.array(preset[key], dtype=np.float32)
    return torch.FloatTensor(as_float).unsqueeze(1).repeat(batch, 1, 1).to(device)


w1 = _load_waveform('w1')
w2 = _load_waveform('w2')


In [7]:
# Optimal-control loop. Each iteration:
#   1. forward  - simulate the current pulses (w1 + i*w2) over all frequencies,
#   2. backward - propagate the costate lambda from the terminal profile error,
#   3. update   - one gradient step on w1 / w2,
#   4. log      - profile error phi_d and RF energy.
start_time = time.time()

sar_loss = []
L2_loss  = []
total_loss = []

# Loop-invariant pieces, computed once instead of every iteration
# (P^T P alone is a (3k)^3 matmul).
PtP = torch.mm(torch.transpose(P, 0, 1), P)   # (3k, 3k)
PtK = torch.mm(torch.transpose(P, 0, 1), K)   # (3k, 1)
# Per-sample angle [rad] -> amplitude [Gauss]; used only for the SAR report.
to_gauss = gamma * time_step * 1e-4
sar_slr = np.sum((slr/to_gauss)**2) * time_step * 1e+6
os.makedirs("./output/test", exist_ok=True)  # savemat below fails if the folder is missing

for e in range(max_N_iter):
    epoch_time = time.time()

    # ---- 1. forward: Mt[:, :, f] = rot_{f-1} @ Mt[:, :, f-1], starting from M = (0, 0, 1).
    Mt = torch.zeros([batch, k, N_time_step+1, 3], device = device)
    Mt[:, :, 0, 2] = 1
    # No squeeze(): keeps the batch axis when batch == 1.
    rot = RF_simul(torch.transpose(w1,1,2), torch.transpose(w2,1,2), freq, T/N_time_step)
    for f in range(1, N_time_step+1):
        Mt[:, :, f, :] = torch.sum(Mt[:,:, f-1, :].unsqueeze(-2) * rot[:,:, f-1, :, :], dim=-1)

    Mt = Mt[:, :, 1:, :]                          # Mt[:, :, u] = magnetisation after sample u
    M_T = Mt[:, :, -1, :].reshape(batch, -1, 1)   # final magnetisation, (B, 3k, 1)

    # ---- 2. backward: terminal costate d(phi_d)/dM_T = P^T P M_T - P^T K,
    # then lambda_{b-1} = rot_b^T lambda_b.
    lambda_T = torch.transpose(torch.matmul(PtP, M_T) - PtK, 1, 2)[:, 0, :]   # (B, 3k)
    lambda_ = torch.zeros((batch, k, N_time_step, 3), device = device)
    lambda_[:, :, -1, :] = lambda_T.reshape(batch, k, 3)

    for b in range(N_time_step-1, 0, -1):
        lambda_[:, :, b-1, :] = torch.sum(lambda_[:,:, b, :].unsqueeze(-1) * rot[:, :, b, :, :], dim = -2)

    # ---- 3. update: gradient for sample u is lambda_u^T B M_u (+ alpha * w_u). Each
    # sample's gradient only reads Mt, lambda_ and its own w entry, so all samples are
    # updated at once.
    M_steps = Mt.permute(0, 2, 1, 3).reshape(batch, N_time_step, 3 * k)        # (B, N, 3k)
    lam_steps = lambda_.permute(0, 2, 1, 3).reshape(batch, N_time_step, 3 * k)  # (B, N, 3k)
    grad_w1 = torch.sum(torch.matmul(lam_steps, B1) * M_steps, dim=-1).unsqueeze(1)  # (B, 1, N)
    grad_w2 = torch.sum(torch.matmul(lam_steps, B2) * M_steps, dim=-1).unsqueeze(1)
    # NOTE(review): d(alpha * sum(w^2))/dw = 2 * alpha * w, but the step uses alpha * w, so
    # it descends phi_d + (alpha / 2) * energy while `loss` below logs
    # phi_d + alpha * energy — confirm which weighting is intended.
    w1 = w1 - mu * (grad_w1 + alpha * w1)
    w2 = w2 - mu * (grad_w2 + alpha * w2)

    # ---- 4. log: phi_d describes the pre-update pulse, energy the updated one.
    dist = torch.matmul(P, M_T) - K
    phi_d = (1/2 * (torch.matmul(torch.transpose(dist,1, 2), dist))).reshape(batch)
    energy = torch.sum(w1**2 + w2**2, dim=(1, 2))
    sar_loss.append(energy.detach().cpu().numpy())
    L2_loss.append(phi_d.detach().cpu().numpy())
    loss = phi_d + alpha * energy

    if e == 0 or (e +1) % 50 == 0:
        sar_rf = torch.sum((w1**2 + w2 **2), dim = 2) / (to_gauss**2) * time_step * 1e+6
        print("★★★★★★★★★★★★★★★★★★★★★★★★")
        print(f"SAR reduction : {(sar_slr - torch.min(sar_rf))/sar_slr *100:.4f} %")
        idx = torch.argmin(loss)
        savemat(f"./output/test/freqtest_{e}.mat", {"L2loss": L2_loss, "sar_loss" : sar_loss, "w1" : w1.detach().cpu().numpy(), "w2" : w2.detach().cpu().numpy(), "total_loss" : total_loss, "iteration" : e, "mu" : mu})
        print(f"iteration = {e}, phi_d = {phi_d[idx].item():.5f}, energy = {energy[idx].item():.5f}")
        print(f"Execution time: {time.time() - epoch_time:.3f}")

    total_loss.append(loss.detach().cpu().numpy())

print(f"Total time: {time.time() - start_time:.1f}")

★★★★★★★★★★★★★★★★★★★★★★★★
SAR reduction : -14.9933 %
iteration = 0, phi_d = 3.38926, energy = 0.32254
Execution time: 0.564
★★★★★★★★★★★★★★★★★★★★★★★★
SAR reduction : -12.8579 %
iteration = 49, phi_d = 0.02392, energy = 0.31655
Execution time: 0.513
★★★★★★★★★★★★★★★★★★★★★★★★
SAR reduction : -12.4797 %
iteration = 99, phi_d = 0.00939, energy = 0.31549
Execution time: 0.513
★★★★★★★★★★★★★★★★★★★★★★★★
SAR reduction : -12.1726 %
iteration = 149, phi_d = 0.00601, energy = 0.31463
Execution time: 0.508


KeyboardInterrupt: 

In [8]:
# Persist the final waveforms and the per-iteration loss histories.
final_results = {
    "L2loss": L2_loss,
    "sar_loss" : sar_loss,
    "w1" : w1.detach().cpu().numpy(),
    "w2" : w2.detach().cpu().numpy(),
    "total_loss" : total_loss,
    "iteration" : e,
    "mu" : mu,
}
savemat("OC_result_inv8_rl_batch.mat", final_results)

In [9]:
print(w1.size())  # same torch.Size as w1.shape

torch.Size([256, 1, 256])
