# Setting

In [None]:
import ast
import csv
import glob
from io import StringIO
import json
import matplotlib.pyplot as plt
import numpy as np
import os
import pandas as pd
import pydicom
import time
from tqdm import tqdm_notebook
from tqdm import trange

from source.data_loader import MyDataLoader, WND
from source.my3dpix2pix import My3dPix2Pix

# Check inference result from training or validation set

In [None]:
## load config
spath = 'result/YOURFOLDER'

# cfg.json inside the result folder is parsed into the `cfg` dict.
cfg_path = os.path.join(spath, 'cfg.json')
with open(cfg_path) as cfg_file:
    cfg = json.load(cfg_file)

In [None]:
df0 = pd.read_feather(cfg['df_path'])

# Every keyword argument below is named exactly like the cfg key it is read
# from, so the arguments are gathered from cfg by name (order preserved).
loader_keys = ('cts', 'img_shape', 'grid', 'window1', 'window2',
               'rescale_intensity', 'splitvar')
DL = MyDataLoader(df0, **{key: cfg[key] for key in loader_keys})

gan_keys = ('L_weights', 'opt', 'lrs', 'smoothlabel', 'fmloss', 'gennoise',
            'randomshift', 'resoutput', 'dropout', 'coordconv', 'resizeconv',
            'multigpu')
gan = My3dPix2Pix(DL, savepath=spath, **{key: cfg[key] for key in gan_keys})

gan.load_final_weights()

In [None]:
## check which cases are in training (0) or validation set (1)
# splitset is also passed as `split=` to DL.imread_slice in the single-case cell.
splitset = 0
DL.case_split[splitset]  # last expression of the cell, so the case list is displayed

In [None]:
## run model on single case
case = 0 # your case number
pos = (0, 0, 8) # x,y fixed to 0, change z-axis number

A, B = DL.imread_slice(case, pos, window=True, split=splitset)
imgs_A = np.array([A])/127.5 - 1.
imgs_B = np.array([B])/127.5 - 1.

%time fake_A = gan.generator.predict(imgs_B)
fake_AA = gan.invert_resoutput(fake_A, imgs_B)

C = fake_A - imgs_B
C[C<0] = 0
fake_A = imgs_B+C

gen_imgs = np.concatenate([imgs_B, fake_A, imgs_A])
gen_imgs = 0.5 * gen_imgs + 0.5

In [None]:
## results
r, c = 3, 3

titles = ['Condition', 'Generated', 'Original']
plt.style.use('default')
fig, axs = plt.subplots(r, c, figsize=(3*c,3*r))
# Column i shows gen_imgs[i]; row j shows index j along its third axis.
for i in range(c):
    axs[0,i].set_title(titles[i])
    for j in range(r):
        ax = axs[j,i]
        fig0 = ax.imshow(gen_imgs[i][:,:,j,0], cmap='gray')
        fig.colorbar(fig0, ax=ax)
plt.show()
plt.close(fig)

# Test trained model on new test set

## Dataframe from dicoms
dicom/YOURDATASET should be in the following format:<br>
YOURDATASET<br>
&nbsp;&nbsp;&nbsp;&nbsp;case1<br>
&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;CT1 containing dicom files<br>
&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;CT2 containing dicom files<br>
&nbsp;&nbsp;&nbsp;&nbsp;case2<br>
&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;CT1 containing dicom files<br>
&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;CT2 containing dicom files<br>
&nbsp;&nbsp;&nbsp;&nbsp;...

In [None]:
_SKIPPED_NAME_PARTS = ('Overlay', 'Referring', 'Acquisition')


def _exported_fields(dataset):
    """Yield ``(column_name, value)`` for every element of ``dataset`` that is exported.

    'Pixel Data' and any element whose name contains 'Overlay', 'Referring'
    or 'Acquisition' are skipped; spaces are removed from the element name.
    """
    for elem in dataset:
        name = elem.name
        if name == 'Pixel Data' or any(part in name for part in _SKIPPED_NAME_PARTS):
            continue
        yield name.replace(' ', ''), elem.value


def _zpos_from_position(position):
    """Return the last component of a stringified position list as a stripped string.

    The components may come back from ``ast.literal_eval`` either as strings
    (``"['1', '2', '3']"``) or as numbers (``"[1.0, 2.0, 3.0]"``), depending on
    how the multi-valued element was written to the CSV buffer. ``str()``
    makes both forms work, where calling ``.strip()`` directly fails on numbers.
    """
    return str(ast.literal_eval(position)[-1]).strip()


def my_dicoms_to_dataframe(basedir, cts):
    """Collect the header fields of all files in ``basedir/<case>/<ct>/`` into a DataFrame.

    The column set is taken from the first ``.dcm`` file found in the first
    case folder's ``cts[0]`` sub-folder; fields of other files that are not in
    that set are dropped. Rows are round-tripped through an in-memory CSV
    before being read by pandas. The resulting frame is also written to
    ``<basedir>/headers.ftr``.

    Parameters
    ----------
    basedir : str
        Directory holding one sub-directory per case.
    cts : sequence of str
        Names of the CT sub-folders to read inside every case folder.

    Returns
    -------
    pandas.DataFrame
        One row per file. The first three columns are ``pid`` (case folder
        name), ``ct`` (CT folder name) and ``zpos`` (last component of
        ``ImagePosition(Patient)``, as a string), followed by ``filepath``
        and the exported header fields.

    Raises
    ------
    FileNotFoundError
        If ``basedir`` has no case directory, or the template folder has no
        ``.dcm`` file.
    """
    caselist = [os.path.join(basedir, x) for x in os.listdir(basedir)
                if os.path.isdir(os.path.join(basedir, x))]
    if not caselist:
        raise FileNotFoundError('no case directories found in %r' % (basedir,))

    file_list = []
    for ct in cts:
        file_list.extend(glob.glob(os.path.join(basedir, '*/'+ct+'/*.*')))

    # The first .dcm file of the first case defines which columns are exported.
    template_dir = os.path.join(caselist[0], cts[0])
    template_files = [os.path.join(template_dir, x) for x in os.listdir(template_dir)
                      if x.lower().endswith('.dcm')]
    if not template_files:
        raise FileNotFoundError('no .dcm file found in %r' % (template_dir,))
    template = pydicom.dcmread(template_files[0])

    headers = ['filepath'] + [name for name, _ in _exported_fields(template)]

    output = StringIO()
    # extrasaction='ignore' drops fields that are not part of the template header.
    csv_writer = csv.DictWriter(output, fieldnames=headers, extrasaction='ignore')
    csv_writer.writeheader()

    for f in tqdm_notebook(file_list):
        row = {'filepath': f}
        row.update(_exported_fields(pydicom.dcmread(f)))
        csv_writer.writerow(row)

    output.seek(0) # we need to get back to the start of the StringIO
    df = pd.read_csv(output)

    # <basedir>/<pid>/<ct>/<file>: the two parent folder names identify the series.
    df['pid'] = df['filepath'].apply(lambda p: os.path.basename(os.path.dirname(os.path.dirname(p))))
    df['ct'] = df['filepath'].apply(lambda p: os.path.basename(os.path.dirname(p)))
    df['zpos'] = df['ImagePosition(Patient)'].apply(_zpos_from_position)

    # Move the three derived columns to the front.
    cols = df.columns.tolist()
    df = df[cols[-3:] + cols[:-3]]

    df.to_feather(os.path.join(basedir, 'headers.ftr'))
    return df

In [None]:
# CT1 = input, CT2 = output

# basedir must contain one folder per case, each holding the CT folders named in cts.
basedir = 'dicom/YOURTESTSET'
cts = ('CT1','CT2')
# Besides returning the frame, this call writes it to <basedir>/headers.ftr.
df = my_dicoms_to_dataframe(basedir, cts)

In [None]:
#### modify headers and save

# Make zpos numeric so slices sort by position rather than as text.
df['zpos'] = pd.to_numeric(df['zpos'])
df = df.sort_values(by=['pid', 'ct', 'zpos'])

# Feather output of the sorted frame with a fresh 0..n-1 index.
df2path = os.path.join(basedir, 'select.ftr')
df2 = df.reset_index(drop=True)
df2.to_feather(df2path)

## run model

In [None]:
## loop for inference on all cases in test set

def loop_over_case(case, notruth=False):
    """Run the generator over one case and save the generated volume as DICOM files.

    Relies on the notebook globals ``gan`` (model object with ``generator``,
    ``data_loader``, ``img_rows``, ``img_cols`` and ``depth``) and ``savedir``
    (root directory for the output).

    The condition volume is fed to the generator in overlapping windows of
    ``gan.depth`` slices, shifted by one slice at a time; overlapping
    predictions are averaged per slice. Each slice is then written into a copy
    of the corresponding ``cts[0]`` DICOM file under ``<savedir>/<pid>/dicom``,
    with ``SeriesNumber`` increased by 99000 and ``'.99'`` appended to
    ``SOPInstanceUID``.

    Parameters
    ----------
    case : tuple
        ``(pid, zs)`` pair: ``pid`` selects the case and ``zs`` is the size of
        the last axis of the generated volume.
    notruth : bool, optional
        If True, the first array returned by ``load_dicoms`` is replaced by
        zeros with the shape and dtype of the second one.

    Returns
    -------
    numpy.ndarray
        Condition, generated and original slice at one randomly chosen index
        of the last axis, stacked along a new last axis.
    """
    # rWND was called here without being imported in any visible cell (only
    # WND is), so a fresh kernel raised NameError.
    # NOTE(review): assumed to live next to WND in source.data_loader - confirm.
    from source.data_loader import rWND

    pid, zs = case

    dcm_A, dcm_B = gan.data_loader.load_dicoms(pid, (0,zs+1))
    if notruth:
        dcm_A = np.zeros(dcm_B.shape, dtype=dcm_B.dtype)

    # One channel per window setting; v/127.5 - 1 maps 0 -> -1 and 255 -> 1.
    # (The windowed copy of dcm_A was never used afterwards and is no longer built.)
    tot_B = np.stack([WND(dcm_B, w) for w in gan.data_loader.window1], axis=-1)
    tot_B = tot_B.astype('float32')/127.5 - 1.

    # Running sum of predictions and number of windows that covered each slice.
    fakes_raw = np.zeros((gan.img_rows,gan.img_cols,zs), dtype=tot_B.dtype)
    counts_raw = np.zeros((gan.img_rows,gan.img_cols,zs), dtype=int)

    for i in tqdm_notebook(range(zs+1-gan.depth)):
        imgs_B = np.expand_dims(tot_B[:,:,i:i+gan.depth,:], axis=0)
        fake_A = gan.generator.predict(imgs_B)
        fake_A = 0.5 * fake_A + 0.5
        # Only channel 0 of the output is kept, converted with the first window2 entry.
        fake_A = rWND(255.*fake_A[:,:,:,:,0], gan.data_loader.window2[0])

        fakes_raw[:,:,i:i+gan.depth] += fake_A[0]
        counts_raw[:,:,i:i+gan.depth] += 1

    # Average overlapping predictions; slices never covered keep a divisor of 1.
    fakes = np.divide(fakes_raw, np.maximum(counts_raw, 1))

    # random sample
    sample_idx = np.random.choice(fakes.shape[-1])
    sample = np.stack((
        dcm_B[:,:,sample_idx].astype(fakes.dtype),
        fakes[:,:,sample_idx],
        dcm_A[:,:,sample_idx].astype(fakes.dtype)
    ), axis=-1)

    # Files of the first CT series of this case serve as templates for the output.
    df1 = gan.data_loader.df
    dcms1 = df1[(df1['pid']==pid)&(df1['ct']==gan.data_loader.cts[0])]['filepath'].tolist()

    newpath = os.path.join(savedir, pid, 'dicom')
    os.makedirs(newpath, exist_ok=True)

    # Wrapping the list (not the enumerate object) lets the progress bar know the total.
    for N, y in enumerate(tqdm_notebook(dcms1)):
        ds = pydicom.dcmread(y)

        # Undo the rescale of the template file to get stored pixel values.
        # NOTE(review): astype('int16') truncates toward zero instead of rounding.
        x = (fakes[:,:,N]-float(ds.RescaleIntercept))/float(ds.RescaleSlope)
        ds.PixelData = x.astype('int16').tobytes()

        ds.SeriesNumber += 99000
        ds.SOPInstanceUID += '.99'

        ds.save_as(os.path.join(newpath, os.path.basename(y)))

    return sample

In [None]:
%%time
## load config + get new dicoms
# result folder that contains cfg.json
spath = 'result/YOURFOLDER'

with open(os.path.join(spath,'cfg.json')) as json_file:
    cfg = json.load(json_file)

# your own test set and names of ct folders
# df_path is the select.ftr file written by the "modify headers and save" cell.
cfg['df_path'] = 'dicom/YOURTESTSET/select.ftr'
cfg['cts'] = ('CT1','CT2')
cfg['splitvar'] = 1.0  # fixed
    
# Build the data loader and the model from the overridden config; each step is timed.
# The %time line magics rely on the backslash continuations below, so keep them.
df0 = pd.read_feather(cfg['df_path'])
%time DL = MyDataLoader(df0, cts=cfg['cts'], img_shape=cfg['img_shape'],\
                grid=cfg['grid'],\
                window1=cfg['window1'], window2=cfg['window2'], rescale_intensity=cfg['rescale_intensity'], splitvar=cfg['splitvar'])

%time gan = My3dPix2Pix(DL, savepath=spath, L_weights=cfg['L_weights'], opt=cfg['opt'], lrs=cfg['lrs'],\
                       smoothlabel=cfg['smoothlabel'], fmloss=cfg['fmloss'],\
                       gennoise=cfg['gennoise'],\
                       randomshift=cfg['randomshift'], resoutput=cfg['resoutput'], dropout=cfg['dropout'],\
                       coordconv=cfg['coordconv'], resizeconv=cfg['resizeconv'], multigpu=cfg['multigpu'])

%time gan.load_final_weights()

In [None]:
## make directory for test results inside result/YOURFOLDER
# savedir is read as a global by loop_over_case when it writes the generated DICOM files.
savedir = gan.make_directory('TESTDIRECTORY')
split = 0
# L holds the cases of the chosen split; its entries are passed to loop_over_case.
L = gan.data_loader.case_split[split]
# indices into L of the cases to process (here: all of them)
choice = np.arange(len(L))

In [None]:
## run loop
# One (condition, generated, original) preview array is collected per case.
samples = []

for idx in tqdm_notebook(choice):
    sample = loop_over_case(L[idx], notruth=False)
    samples.append(sample)

In [None]:
# One row per processed case, one column per image kind.
r = len(samples)
c = 3

titles = ['Condition', 'Generated', 'Original']
plt.style.use('default')
# squeeze=False keeps axs two-dimensional even when there is only one sample
# row; without it axs is 1-D for r == 1 and axs[0,i] raises IndexError.
fig, axs = plt.subplots(r, c, figsize=(3*c,3*r), squeeze=False)
for i in range(c):
    axs[0,i].set_title(titles[i])
    for j in range(r):
        fig0 = axs[j,i].imshow(samples[j][:,:,i], cmap='gray')
        fig.colorbar(fig0, ax=axs[j,i])
plt.show()
plt.close(fig)