In [None]:
# %matplotlib inline
# %matplotlib notebook
import cv2
import os
# os.environ['CUDA_LAUNCH_BLOCKING'] = "1"
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"  # Arrange GPU devices starting from 0
# os.environ["CUDA_VISIBLE_DEVICES"]= "1"  # Set the GPU 2 to use

import torch
import numpy as np
import torch.nn.functional as F
import math
from torch import nn, optim
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from torch.utils.data import  TensorDataset, DataLoader
print(torch.cuda.is_available())
print(torch.__version__)

# Physical GPU to train on (indices follow PCI bus order because of
# CUDA_DEVICE_ORDER set above). Falls back to GPU 0 when this machine has
# fewer GPUs, and to the CPU when CUDA is unavailable.
GPU_INDEX = 2
if torch.cuda.is_available():
    if torch.cuda.device_count() > GPU_INDEX:
        device = torch.device(GPU_INDEX)
    else:
        print(f'GPU {GPU_INDEX} not found; falling back to GPU 0')
        device = torch.device(0)
else:
    device = torch.device("cpu")

print('Device:', device)
# current_device() raises when CUDA is unavailable, so only query it on GPU
# machines. NOTE: it reports the process-default CUDA device, which is not
# necessarily `device`.
if torch.cuda.is_available():
    print('Current cuda device:', torch.cuda.current_device())
print('Count of using GPUs:', torch.cuda.device_count())

import scipy.io
import cmath                
from skimage.metrics import structural_similarity as ssim
# import utills

from torch.utils.tensorboard import SummaryWriter

# TensorBoard's default `log_dir` is "runs"; a more specific sub-directory is
# used here. This module-level writer is handed to train_model() during training.
writer = SummaryWriter(log_dir='runs/experiment1')
import nibabel as nib
from sklearn.linear_model import LinearRegression
import pydicom as dcm

import pandas as pd
from nilearn.glm.first_level import make_first_level_design_matrix
from nilearn.glm.first_level import FirstLevelModel
from nilearn import plotting
from nilearn.image import concat_imgs, mean_img, resample_img
import nilearn

from utils.nifti_utils import *
from utils.common import *
from utils.fmri_utils_true import *


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class ComplexConv(nn.Module):
    """Complex-valued 2-D convolution built from two real ``nn.Conv2d`` layers.

    ``x[:, 0]`` holds the real part and ``x[:, 1]`` the imaginary part. Each
    part is fed to ``nn.Conv2d``, so ``x`` is 5-D:
    ``[batch, 2, channel, axis1, axis2]``. With weights ``W = W_re + i*W_im``:

        real      = W_re * x_re - W_im * x_im
        imaginary = W_re * x_im + W_im * x_re

    Note: with ``bias=True`` both convolutions add their own bias. The real
    part therefore receives ``b_re - b_im`` and the imaginary part
    ``b_re + b_im``.

    Constructor arguments mirror ``nn.Conv2d``.
    """

    def __init__(self, in_channel, out_channel, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True):
        super().__init__()
        # Kept as attributes; forward() itself does not use self.device.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.padding = padding

        shared = dict(stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias)
        # conv_re must be created (and randomly initialised) before conv_im.
        self.conv_re = nn.Conv2d(in_channel, out_channel, kernel_size, **shared)
        self.conv_im = nn.Conv2d(in_channel, out_channel, kernel_size, **shared)

    def forward(self, x):
        x_re, x_im = x[:, 0], x[:, 1]
        real_part = self.conv_re(x_re) - self.conv_im(x_im)
        imag_part = self.conv_re(x_im) + self.conv_im(x_re)
        return torch.stack((real_part, imag_part), dim=1)
        
#%%
if __name__ == "__main__":
    # Smoke test for ComplexConv on random data.
    # Input layout: [batch, 2 (real/imag), channel, axis1, axis2]; the sizes
    # below are arbitrary.
    x = torch.randn(10, 2, 3, 100, 100)

    # in_channel, out_channel and kernel_size are the required arguments.
    complexConv = ComplexConv(in_channel=3, out_channel=10, kernel_size=(5, 5))

    # No padding and stride 1, so y has shape [10, 2, 10, 96, 96].
    y = complexConv(x)

# Module-level switch read by DoubleConv when it is constructed:
# True inserts a BatchNorm2d after each convolution.
batchnorm=True


class UNet(nn.Module):
    """Six-level 2-D U-Net with a global residual connection.

    Structure:
    - Encoder: ``DoubleConv`` followed by five ``Down`` stages.
    - Decoder: five ``Up`` stages with skip connections.
    - Head: a 1x1 ``OutConv``.

    The network input is added to the output, so the model learns a residual
    and ``n_channels`` must equal ``out_ch``.

    Args:
        n_channels: number of input channels.
        out_ch: number of output channels (must equal ``n_channels`` for the
            residual addition).
        start_channel: width of the first level; the encoder widths are
            ``start_channel * (1, 2, 4, 8, 16, 32)``.
        alpha: stored as ``self.alpha``; not used in ``forward``.
        bias: bias flag for the convolutions in DoubleConv/Down/Up.
        bilinear: upsample with ``nn.Upsample`` instead of ``ConvTranspose2d``.
    """

    def __init__(self, n_channels, out_ch, start_channel=64, alpha=1, bias=False, bilinear=False):
        super(UNet, self).__init__()
        self.n_channels = n_channels
        self.bilinear = bilinear
        base = start_channel
        # Bilinear Up blocks do not reduce channels themselves. The deepest
        # encoder output and every decoder output except the last are therefore
        # halved, so each concatenation [skip, upsampled] matches the Up block's
        # in_channels and the last decoder output matches OutConv.
        # With bilinear=False, factor == 1 and the widths are unchanged.
        factor = 2 if bilinear else 1
        self.inc = DoubleConv(n_channels, base, bias=bias)
        self.down1 = Down(base, base * 2, bias=bias)
        self.down2 = Down(base * 2, base * 4, bias=bias)
        self.down3 = Down(base * 4, base * 8, bias=bias)
        self.down4 = Down(base * 8, base * 16, bias=bias)
        self.down5 = Down(base * 16, base * 32 // factor, bias=bias)
        self.up1 = Up(base * 32, base * 16 // factor, bilinear, bias=bias)
        self.up2 = Up(base * 16, base * 8 // factor, bilinear, bias=bias)
        self.up3 = Up(base * 8, base * 4 // factor, bilinear, bias=bias)
        self.up4 = Up(base * 4, base * 2 // factor, bilinear, bias=bias)
        self.up5 = Up(base * 2, base, bilinear, bias=bias)

        self.outc = OutConv(base, out_ch, bias=False)
        self.alpha = alpha

    def forward(self, x):
        """Run the encoder-decoder and add the input back (residual output)."""
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        x6 = self.down5(x5)

        decoded = self.up1(x6, x5)
        decoded = self.up2(decoded, x4)
        decoded = self.up3(decoded, x3)
        decoded = self.up4(decoded, x2)
        decoded = self.up5(decoded, x1)
        logits = self.outc(decoded) + x
        return logits

class DoubleConv(nn.Module):
    """(convolution => [BN] => PReLU) * 2, preserving the spatial size.

    Args:
        in_channels: input channels.
        out_channels: output channels.
        kernel_size: convolution kernel size (int or tuple). The padding is
            ``kernel_size // 2`` per dimension, so odd kernels keep H and W
            unchanged (the default 3 gives padding 1).
        mid_channels: channels between the two convolutions; a falsy value
            means ``out_channels``.
        bias: whether the convolutions have a bias term.
        use_batchnorm: insert ``BatchNorm2d`` after each convolution. ``None``
            (default) falls back to the module-level ``batchnorm`` flag, read
            at construction time.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, mid_channels=None, bias=False, use_batchnorm=None):
        super().__init__()
        if not mid_channels:
            mid_channels = out_channels
        if use_batchnorm is None:
            use_batchnorm = batchnorm == True  # same test as the original global switch
        # A fixed padding of 1 is only "same" padding for 3x3 kernels; derive it instead.
        if isinstance(kernel_size, int):
            padding = kernel_size // 2
        else:
            padding = tuple(k // 2 for k in kernel_size)

        def conv_bn_act(cin, cout):
            # Conv2d -> [BatchNorm2d] -> PReLU, in this order.
            layers = [nn.Conv2d(cin, cout, kernel_size=kernel_size, padding=padding, bias=bias)]
            if use_batchnorm:
                layers.append(nn.BatchNorm2d(cout))
            layers.append(nn.PReLU())
            return layers

        # Sequential indices (and thus state_dict keys) match the original layout.
        self.double_conv = nn.Sequential(
            *conv_bn_act(in_channels, mid_channels),
            *conv_bn_act(mid_channels, out_channels),
        )

    def forward(self, x):
        return self.double_conv(x)


class Down(nn.Module):
    """Downscaling with a learned stride-2 convolution, then DoubleConv.

    The first stage is a 2x2 convolution with stride 2, used instead of
    max-pooling. It halves H and W (odd sizes are floored) and keeps
    ``in_channels``. ``DoubleConv`` then maps to ``out_channels``.
    """

    def __init__(self, in_channels, out_channels, bias=False):
        super().__init__()
        # The attribute name is historical (no max-pool is used); renaming it
        # would change the state_dict keys.
        self.maxpool_conv = nn.Sequential(
            nn.Conv2d(in_channels, in_channels, kernel_size=2, stride=2, padding=0, bias=bias),
            DoubleConv(in_channels, out_channels, bias=bias),
        )

    def forward(self, x):
        return self.maxpool_conv(x)


class Up(nn.Module):
    """Upscaling, concatenation with the skip connection, then DoubleConv.

    Args:
        in_channels: channels of the concatenated tensor ``[skip, upsampled]``
            that is fed to DoubleConv.
        out_channels: output channels.
        bilinear: use bilinear ``nn.Upsample``, which keeps the channel count,
            instead of a ``ConvTranspose2d`` that halves it.
        bias: bias flag forwarded to DoubleConv.
    """

    def __init__(self, in_channels, out_channels, bilinear=False, bias=False):
        super().__init__()

        # if bilinear, use the normal convolutions to reduce the number of channels
        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
            # mid_channels must be given by keyword: DoubleConv's third
            # positional parameter is kernel_size, not mid_channels.
            self.conv = DoubleConv(in_channels, out_channels, mid_channels=in_channels // 2, bias=bias)
        else:
            self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2, bias=False)
            self.conv = DoubleConv(in_channels, out_channels, bias=bias)

    def forward(self, x1, x2):
        """Upsample ``x1``, pad it to ``x2``'s spatial size, concat ``[x2, x1]`` and convolve."""
        x1 = self.up(x1)
        # inputs are NCHW: dims 2 and 3 are height and width
        diff_y = x2.size()[2] - x1.size()[2]
        diff_x = x2.size()[3] - x1.size()[3]

        x1 = F.pad(x1, [diff_x // 2, diff_x - diff_x // 2,
                        diff_y // 2, diff_y - diff_y // 2])
        # if you have padding issues, see
        # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
        # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd
        x = torch.cat([x2, x1], dim=1)
        return self.conv(x)


class OutConv(nn.Module):
    """1x1 convolution that projects ``in_channels`` feature maps to ``out_channels``."""

    def __init__(self, in_channels, out_channels, bias=False):
        super().__init__()
        self.conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, bias=bias)

    def forward(self, x):
        projected = self.conv(x)
        return projected

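GE-EPI → SE-EPI model: a network that maps gradient-echo (GE) EPI images to spin-echo (SE) EPI images. It is trained by deliberately overfitting to a single dataset. (Original note: GE epi에서 SE epi로 가는 모델. 데이터는 하나로 오버피팅)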

In [None]:
## parameters
mod_num=80                 # NOTE(review): not used in this cell
center_slice=104           # NOTE(review): not used in this cell
mask=False                 # False -> the SE-EPI mask is replaced by an all-ones mask
dummy_measurement_num=10   # leading GE-EPI volumes averaged into the training input; later volumes are the in-vivo inference input
train_copies=6             # copies of the single (GE, SE) slice pair that form the training set
epoch=100
cutoff_percentage=0        # passed to calculate_glm() in the GLM cell
threshold=3                # NOTE(review): not used in this cell
## outputs
np_image_list=[]           # one network-output array per slice
np_mask_list=[]            # one mask array per slice
slice_scale_z_map_nonprocessed_list=[]
slice_scale_z_map_processed_list=[]

AP_cor=nib.load('/home/milab/SSD_8TB/LeeSooHyung/datasets/fmri_inhomo_test_2/2024.06.03/JKJin/Stroop/Stroop_task_HC_SE_EPI_AP.nii')
GE_39ms=nib.load('/home/milab/SSD_8TB/LeeSooHyung/datasets/fmri_inhomo_test_2/2024.06.03/JKJin/Stroop/Stroop_task_HC_2mm.nii').get_fdata()[:,:,:]
SE_EPI_original=AP_cor.get_fdata()[:,:,:]
GE_EPI=GE_39ms
# At least one measurement must remain after the averaged ones for inference.
assert GE_EPI.shape[3] > dummy_measurement_num, "not enough GE-EPI measurements"
# Mean of the first `dummy_measurement_num` volumes along axis 3 (measurements).
GE_EPI_avg=GE_EPI[:,:,:,:dummy_measurement_num].mean(axis=3)

for slice_num in range(GE_EPI.shape[2]):
    print("Slice_num:",slice_num)

    AP_mask_sliced=make_mask_sliced_direct(AP_cor,slice_num)
    if not mask:
        AP_mask_sliced=copy_nii_header(np.ones(AP_mask_sliced.get_fdata().shape),AP_mask_sliced)
    np_mask_list.append(AP_mask_sliced.get_fdata())

    # Training set: the averaged GE slice (input) and the SE slice (label),
    # each repeated `train_copies` times, with a channel axis inserted at dim 1.
    SE_slice=SE_EPI_original[:,:,slice_num]
    GE_slice=GE_EPI_avg[:,:,slice_num]
    train_input=torch.unsqueeze(torch.as_tensor(np.stack([GE_slice]*train_copies),dtype=torch.float32),dim=1)
    train_label=torch.unsqueeze(torch.as_tensor(np.stack([SE_slice]*train_copies),dtype=torch.float32),dim=1)
    print(train_label.shape)

    # In-vivo inference input: every measurement after the averaged ones,
    # reordered from (X, Y, T) to (T, 1, X, Y).
    invivo_volumes=GE_39ms[:,:,slice_num,dummy_measurement_num:]
    invivo_input=torch.permute(torch.unsqueeze(torch.tensor(invivo_volumes),dim=2),(3,2,0,1)).float()

    dataset_2ch=TensorDataset(train_input,train_label)
    # The label slot is filled with the input itself (placeholder; presumably
    # ignored by testset_output — confirm).
    invivo=TensorDataset(invivo_input,invivo_input)

    traindataset=DataLoader(dataset_2ch, batch_size=1, shuffle=True)
    valdataset=DataLoader(dataset_2ch, batch_size=1, shuffle=True)  # NOTE(review): not used in this cell
    invivodataset=DataLoader(invivo, batch_size=1, shuffle=False)

    # NOTE(review): bold_no_pre is not used in this cell.
    bold_no_pre=dataset_to_tensor(invivodataset)
    bold_no_pre=torch.unsqueeze(torch.permute(torch.squeeze(bold_no_pre),(1,2,0)),2)

    # Drop the previous slice's model/optimizer/scheduler first. Otherwise the
    # old optimizer still references the old parameters (and its state), so
    # empty_cache() cannot release them before the new model is moved to `device`.
    model_1=optimizer_1=scheduler=None
    torch.cuda.empty_cache()
    model_1=UNet(2,2,64).to(device)

    loss_function_k=torch.nn.L1Loss()
    loss_function_img=torch.nn.MSELoss()
    optimizer_1=optim.AdamW(model_1.parameters(), lr=0.001, weight_decay=0)
    # NOTE: with step_size=100 and epoch=100 the learning rate never decays during training.
    scheduler=optim.lr_scheduler.StepLR(optimizer_1,100,gamma=0.8)
    norm_scale=1000

    check_num=epoch
    for i in range(epoch):
        total, total_loss = train_model(model_1, traindataset, optimizer_1, loss_function_k, loss_function_img, device, norm_scale, writer, epoch_number=i,noise=None, frequency_loss=True, image_loss=False,l1norm=None,tv_loss=None)
        scheduler.step()
        # check_num == epoch, so inference runs once, after the final epoch.
        if i%check_num==check_num-1:
            bold=testset_output(model_1,invivodataset,device,norm_scale)
            # 3-D output (d0, d1, d2) -> magnitude array of shape (d1, d2, 1, d0).
            bold_per=abs(torch.unsqueeze(torch.permute(bold,(1,2,0)),2).float())
            np_image_list.append(bold_per.numpy())




In [None]:
# Each np_image_list entry is a 4-D (d0, d1, 1, d3) array (see the training
# cell), so stacking the slices on axis 2 gives (d0, d1, n_slices, 1, d3).
np_image_array=np.stack(np_image_list,axis=2)
# Remove only the singleton axis 3 -> (d0, d1, n_slices, d3). A bare np.squeeze()
# would also drop the slice or time axis whenever it has length 1.
preprocessed_nii=copy_nii_header(np.squeeze(np_image_array,axis=3),AP_cor)
nib.save(preprocessed_nii,'2024.06.03.JKJ_stroop.nii')
# nib.save(preprocessed_nii,'2024.06.04.KJYoon_visual.nii')
# nib.save(preprocessed_nii,'2024.06.04.KJYoon_visual.nii')


In [None]:
# Import the submodule explicitly: `import nilearn` alone does not guarantee
# that `nilearn.masking` has been loaded.
import nilearn.masking

# calculate_glm (from the utils modules) receives the network output and an EPI
# mask computed from the SE-EPI reference image.
z_map,data=calculate_glm(preprocessed_nii,nilearn.masking.compute_epi_mask(AP_cor),percentage=cutoff_percentage)