In [1]:
import h5py
import numpy as np
import keras
import sys
import matplotlib.pyplot as plt
import pandas as pd
import sys
import common
import DataProcessing as dp
import TrainModel as m
import TrainAndTest as tt
import tensorflow as tf
import tensorflow.keras.layers as tfl
from tensorflow.keras.optimizers import Adam
from keras.metrics import RootMeanSquaredError
import matplotlib.pyplot as plt
from prettytable import PrettyTable


In [2]:
def processFullMRIFileRFA0(noise, testNoise=None):
    sample,b = common.readAllAcqs(noise)
    sig1 = sample[:,:16,:]
    noise1 = sample[:,16:18,:]
    l = sig1.shape[0]
    l1 = int(l*0.8)
    sigTrain = dp.SplitComplexR(sig1[0:l1])
    noiseTrain = dp.SplitComplexR(noise1[0:l1])
    sigVal = dp.SplitComplexR(sig1[l1:l])
    noiseVal = dp.SplitComplexR(noise1[l1:l])
    sigTest = dp.SplitComplexR(sig1)
    noiseTest = dp.SplitComplexR(noise1)

    if testNoise != None :
        test,c = common.readAllAcqs(testNoise)
        sigT = test[:,:16,:]
        noiseT = test[:,16:18,:]
        sigTest = dp.SplitComplexR(sigT)
        noiseTest = dp.SplitComplexR(noiseT)

    return(sigTrain,noiseTrain,sigVal,noiseVal,sigTest,noiseTest)

def mean(sig):
    return np.mean(sig)

def peak(sig):
     return np.max(sig)

def std(sig):
     return np.std(sig, dtype=np.float64)

def experiment2DTable(meanBefore,peakBefore,stdBefore,meanSupComb,peakSupComb,stdSupComb,SRMeanSupComb,SRPeakSupComb,SRStdSupComb):
    table = PrettyTable(["","Before suppression","After suppression with Channel 16 and 17",
                          "Suppression Rate3"])
    table.add_row(["mean",meanBefore,meanSupComb,SRMeanSupComb]) 
    table.add_row(["peak",peakBefore,peakSupComb,SRPeakSupComb]) 
    table.add_row(["standard deviation",stdBefore,stdSupComb,SRStdSupComb]) 
    print(table)

def experiment2DCalculation(before,supComb):
    meanBefore = mean(before)
    peakBefore = peak(before)
    stdBefore = std(before)
    meanSupComb = mean(supComb)
    peakSupComb = peak(supComb)
    stdSupComb = std(supComb)
    SRMeanSupComb = 1 - (np.abs(meanSupComb)/np.abs(meanBefore))
    SRPeakSupComb = 1 - (np.abs(peakSupComb)/np.abs(peakBefore))
    SRStdSupComb = 1 - (np.abs(stdSupComb)/np.abs(stdBefore))
    experiment2DTable(meanBefore,peakBefore,stdBefore,meanSupComb,peakSupComb,stdSupComb,SRMeanSupComb,SRPeakSupComb,SRStdSupComb)
    #return meanBefore,peakBefore,stdBefore,meanSupComb,peakSupComb,stdSupComb,SRMeanSupComb,SRPeakSupComb,SRStdSupComb

In [22]:
sigTrain = []
noiseTrain = []
sigVal = []
noiseVal = []
sigTest = []
noiseTest = []

## Define file paths
date = '20250122'
mode = 'AM'
type = 'Square'
baseline = 'C:/JiaxingData/EMINoise/'+date+'/'+mode+type+'BFA77.h5'
for trial in ['1','2','3','4','5','6','7','8','9','10']:
    noise = 'C:/JiaxingData/EMINoise/'+date+'/'+mode+type+'SFA0_'+trial+'.h5'
    test = 'C:/JiaxingData/EMINoise/'+date+'/'+mode+type+'SFA30_'+trial+'.h5'

    ## pre-processing data
    sigTrainC,noiseTrainC,sigValC,noiseValC,sigTestC,noiseTestC = processFullMRIFileRFA0(noise,test)
    for i in range(sigTrainC.shape[0]):
        sigTrain.append(sigTrainC[i])
        noiseTrain.append(noiseTrainC[i])
    for j in range(sigValC.shape[0]):
        sigVal.append(sigValC[j])
        noiseVal.append(noiseValC[j])
    for k in range(sigTestC.shape[0]):
        sigTest.append(sigTestC[k])
        noiseTest.append(noiseTestC[k])

sigTrain = np.array(sigTrain)
noiseTrain = np.array(noiseTrain)
sigVal = np.array(sigVal)
noiseVal = np.array(noiseVal)
sigTest = np.array(sigTest)
noiseTest = np.array(noiseTest)


In [25]:

### Model training
bs = 16
epoch_num = 400
lr = 0.0002

N = np.array(noiseTrain).shape[2]
model = m.get_model(N)

def lrDeacy(epoch):
    return lr*0.9**(epoch//8)
# learning rate uodate callback
LRC = tf.keras.callbacks.LearningRateScheduler(lrDeacy)

# early stopping callback, stop the training if the validation loss stop redcuing
ESC = tf.keras.callbacks.EarlyStopping(
    monitor='val_loss',   # Metric to be monitored
    patience=10,           # Number of epochs to wait for improvement
    restore_best_weights=True # Restore model weights from the epoch with the best value of the monitored metric
    )
model.summary()
model.compile(optimizer=Adam(learning_rate = lr), loss = "mse",metrics = [RootMeanSquaredError()])
history = model.fit(noiseTrain, sigTrain, epochs = epoch_num, 
                    validation_data = (noiseVal,sigVal),
                    callbacks = [LRC,ESC], batch_size = bs)

# model name format: 
# model_xxx(batch size)_xxx(feature s.a.image,signal only, flip angle 0)_x(subject(s) or object(o))
modelName = 'Model'
#model.save('C:/JiaxingData/EMINoise/'+date+'/'+modelName+'.h5')

#plot corresponding loss and validation loss
#tt.plotHist("loss",history,modelName)
#tt.plotHist("val_loss",history,modelName)





Epoch 1/400
[1m256/256[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m11s[0m 34ms/step - loss: 58459.0391 - root_mean_squared_error: 241.7807 - val_loss: 56053.5820 - val_root_mean_squared_error: 236.7564 - learning_rate: 2.0000e-04
Epoch 2/400
[1m256/256[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m8s[0m 32ms/step - loss: 54508.2891 - root_mean_squared_error: 233.4662 - val_loss: 50842.7695 - val_root_mean_squared_error: 225.4834 - learning_rate: 2.0000e-04
Epoch 3/400
[1m256/256[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m8s[0m 32ms/step - loss: 49955.4453 - root_mean_squared_error: 223.5018 - val_loss: 45684.5352 - val_root_mean_squared_error: 213.7394 - learning_rate: 2.0000e-04
Epoch 4/400
[1m256/256[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m8s[0m 32ms/step - loss: 45078.3555 - root_mean_squared_error: 212.3099 - val_loss: 40699.8945 - val_root_mean_squared_error: 201.7421 - learning_rate: 2.0000e-04
Epoch 5/400
[1m256/256[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m

In [5]:
def mean(sig):
    return np.mean(sig)

def peak(sig):
     return np.max(sig)

def std(sig):
     return np.std(sig, dtype=np.float64)

def experiment2DTable(meanBefore,peakBefore,stdBefore,meanSupComb,peakSupComb,stdSupComb,SRMeanSupComb,SRPeakSupComb,SRStdSupComb):
    table = PrettyTable(["","Before suppression","After suppression with Channel 16 and 17",
                          "Suppression Rate3"])
    table.add_row(["mean",meanBefore,meanSupComb,SRMeanSupComb]) 
    table.add_row(["peak",peakBefore,peakSupComb,SRPeakSupComb]) 
    table.add_row(["standard deviation",stdBefore,stdSupComb,SRStdSupComb]) 
    print(table)

def experiment2DCalculation(before,supComb):
    meanBefore = mean(before)
    peakBefore = peak(before)
    stdBefore = std(before)
    meanSupComb = mean(supComb)
    peakSupComb = peak(supComb)
    stdSupComb = std(supComb)
    SRMeanSupComb = 1 - (np.abs(meanSupComb)/np.abs(meanBefore))
    SRPeakSupComb = 1 - (np.abs(peakSupComb)/np.abs(peakBefore))
    SRStdSupComb = 1 - (np.abs(stdSupComb)/np.abs(stdBefore))
    experiment2DTable(meanBefore,peakBefore,stdBefore,meanSupComb,peakSupComb,stdSupComb,SRMeanSupComb,SRPeakSupComb,SRStdSupComb)
    #return meanBefore,peakBefore,stdBefore,meanSupComb,peakSupComb,stdSupComb,SRMeanSupComb,SRPeakSupComb,SRStdSupComb

In [26]:
### Testing

def testProcessing(test):
    test,c = common.readAllAcqs(test)
    sigT = test[:,:16,:]
    noiseT = test[:,16:18,:]
    sigTest = dp.SplitComplexR(sigT)
    noiseTest = dp.SplitComplexR(noiseT)
    return sigTest,noiseTest

date = '20250122'
for mode in ['AM']:
    for type in ['Square']:
        #'1','2','3','4','5','6','7','8','9','10','11','12','13','14','15','16','17','18','19','20','21','22','23','24','25'
        for trial in ['1','2','3','4','5','6','7','8','9','10']:
            sigTest,noiseTest = testProcessing('C:/JiaxingData/EMINoise/'+date+'/'+mode+type+'SFA30_'+trial+'.h5')
            # predicting the noise map using the model obtained above
            predicted = model.predict(noiseTest)
            corrected = sigTest - predicted
            #mse = np.mean(np.square(corrected - baseline))#[subLength:length]))
            #print("Mean square Error on sample Image:", mse)
# Initialize array for cleaned data
#sampleCorrected = np.zeros((length,16,nPoints),dtype=np.complex64)
#sampleCorrected[0:subLength] = sample[0:subLength]
#sampleCorrected[subLength:length] = corrected

# Convert data into complex values and then k-space and image space
            corrected = dp.ConvergeComplexR(corrected)
            #baseline = tt.ConvergeComplex(baseline)
            sampleMRI = dp.ConvergeComplexR(sigTest)
            #NoiseMap = dp.toImg(dp.toKSpace(tt.ConvergeComplexR(predicted),noise))
            #CleanImg = dp.toImg(dp.toKSpace(corrected,noise))
            #NoisyImg = dp.toImg(dp.toKSpace(sampleMRI,noise))
            #BaselineImg = dp.toImg(dp.toKSpace(baseline,baseline))

            corrected = np.mean(corrected,axis = 0)
            sampleMRI = np.mean(sampleMRI,axis = 0)
            experiment2DCalculation(sampleMRI,corrected)

            # Creating new h5py file that stores the noise removed data
        #tt.storePrediction(fPath,noiseName,dp.complexRearrangement(corrected))
        #tt.storePrediction18To16(fPath, noiseName)

            # plot the results
            #tt.plotAll(BaselineImg,NoisyImg,CleanImg,NoiseMap)
            #tt.plotSamples(CleanImg,NoisyImg,BaselineImg,date)
            #tt.plotFFTComparsion(corrected,sampleMRI,baseline,date)

[1m16/16[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 15ms/step
+--------------------+----------------------------+------------------------------------------+--------------------+
|                    |     Before suppression     | After suppression with Channel 16 and 17 | Suppression Rate3  |
+--------------------+----------------------------+------------------------------------------+--------------------+
|        mean        | (0.046039842+0.078493305j) |       (-0.035626553-0.022560803j)        |     0.5365985      |
|        peak        |   (38.273678+3.9628289j)   |          (3.1789117+2.058277j)           |     0.9015787      |
| standard deviation |     16.561834709451578     |            0.5831182396559996            | 0.9647914467276247 |
+--------------------+----------------------------+------------------------------------------+--------------------+
[1m16/16[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 14ms/step
+--------------------+------------------

In [20]:
def preScanProcessing(test):
    sig = np.zeros((37,16,1024))
    noise = np.zeros((37,2,1024))
    test,c = common.readAllAcqs(test,table_name="noise")
    sigT = test[:,:16,:]
    for i in range(37):
        sig[i] = sigT[:,:,i*1024:(i+1)*1024] 
    noiseT = test[:,16:18,:]
    for i in range(37):
        noise[i] = noiseT[:,:,i*1024:(i+1)*1024]
    sig = dp.SplitComplexR(sig)
    noise = dp.SplitComplexR(noise)
    sig = np.squeeze(sig)
    noise = np.squeeze(noise)
    sigTest = sig[0:28]
    noiseTest = noise[0:28]
    sigVal = sig[28:37]
    noiseVal = noise[28:37]
    return sigTest,noiseTest,sigVal,noiseVal


class Callback(tf.keras.callbacks.Callback):
    SHOW_NUMBER = 10
    counter = 0
    epoch = 0

    def on_epoch_begin(self, epoch, logs=None):
        self.epoch = epoch

    def on_train_batch_end(self, batch, logs=None):
        if self.counter == self.SHOW_NUMBER or self.epoch == 1:
            print('Epoch: ' + str(self.epoch) + ' loss: ' + str(logs['loss']))
            if self.epoch > 1:
                self.counter = 0
        self.counter += 1

date = '20250122'
for mode in ['AM']:
    for type in ['Square']:
        #'1','2','3','4','5','6','7','8','9','10','11','12','13','14','15','16','17','18','19','20','21','22','23','24','25'
        for trial in ['1','2','3','4','5','6','7','8','9','10']:
            noise =  'C:/JiaxingData/EMINoise/'+date+'/'+mode+type+'SFA30_'+trial+'.h5'
            sigTrain,noiseTrain,sigVal,noiseVal = preScanProcessing(noise)
            ### Model training
            bs = 4
            epoch_num = 1000
            lr = 0.0002

            N = np.array(noiseTrain).shape[2]
            model = m.get_model(N)

            def lrDeacy(epoch):
                return lr*0.9**(epoch//40)
            # learning rate uodate callback
            LRC = tf.keras.callbacks.LearningRateScheduler(lrDeacy)

            # early stopping callback, stop the training if the validation loss stop redcuing
            ESC = tf.keras.callbacks.EarlyStopping(
    monitor='val_loss',   # Metric to be monitored
    patience=5,           # Number of epochs to wait for improvement
    restore_best_weights=True # Restore model weights from the epoch with the best value of the monitored metric
    )
            #model.summary()
            model.compile(optimizer=Adam(learning_rate = lr), loss = "mse",metrics = [RootMeanSquaredError()])
            history = model.fit(noiseTrain, sigTrain, epochs = epoch_num,validation_data = (noiseVal,sigVal),verbose=0,
                    callbacks = [ESC], batch_size = bs)
            print(len(history.history['loss']))
        

            sigTest,noiseTest = testProcessing('C:/JiaxingData/EMINoise/'+date+'/'+mode+type+'SFA30_'+trial+'.h5')
            # predicting the noise map using the model obtained above
            predicted = model.predict(noiseTest)
            corrected = sigTest - predicted
            #mse = np.mean(np.square(corrected - baseline))#[subLength:length]))
            #print("Mean square Error on sample Image:", mse)
# Initialize array for cleaned data
#sampleCorrected = np.zeros((length,16,nPoints),dtype=np.complex64)
#sampleCorrected[0:subLength] = sample[0:subLength]
#sampleCorrected[subLength:length] = corrected

# Convert data into complex values and then k-space and image space
            corrected = dp.ConvergeComplexR(corrected)
            #baseline = tt.ConvergeComplex(baseline)
            sampleMRI = dp.ConvergeComplexR(sigTest)
            #NoiseMap = dp.toImg(dp.toKSpace(tt.ConvergeComplexR(predicted),noise))
            #CleanImg = dp.toImg(dp.toKSpace(corrected,noise))
            #NoisyImg = dp.toImg(dp.toKSpace(sampleMRI,noise))
            #BaselineImg = dp.toImg(dp.toKSpace(baseline,baseline))

            corrected = np.mean(corrected,axis = 0)
            sampleMRI = np.mean(sampleMRI,axis = 0)
            experiment2DCalculation(sampleMRI,corrected)
            

290
[1m16/16[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 14ms/step


  arrmean = umr_sum(arr, axis, dtype, keepdims=True, where=where)


+--------------------+----------------------------+------------------------------------------+----------------------+
|                    |     Before suppression     | After suppression with Channel 16 and 17 |  Suppression Rate3   |
+--------------------+----------------------------+------------------------------------------+----------------------+
|        mean        | (0.046039842+0.078493305j) |         (-2.5007937+0.24320744j)         |      -26.611126      |
|        peak        |   (38.273678+3.9628289j)   |          (49.301044-5.2035494j)          |     -0.28838623      |
| standard deviation |     16.561834709451578     |            19.167594858858077            | -0.15733523459930643 |
+--------------------+----------------------------+------------------------------------------+----------------------+
263
[1m16/16[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 15ms/step
+--------------------+-----------------------------+------------------------------------------+-

In [19]:
sig = np.zeros((37,16,1024))
noise = np.zeros((37,2,1024))
test,c = common.readAllAcqs('C:/JiaxingData/EMINoise/20250121/AMSquareFA77_'+trial+'.h5',table_name="noise")
sigT = test[:,:16,:]
for i in range(37):
    sig[i] = sigT[:,:,i*1024:(i+1)*1024] 
noiseT = test[:,16:18,:]
for i in range(37):
    noise[i] = noiseT[:,:,i*1024:(i+1)*1024] 
sigTest = dp.SplitComplexR(sig)
noiseTest = dp.SplitComplexR(noise)
sigTest = np.squeeze(sigTest)
noiseTest = np.squeeze(noiseTest)
print(noiseTest.shape)

(37, 2, 512, 2)
