In [1]:
import sys
import logging
import numpy as np
import nibabel as nib
import torch
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

import monai
from monai.data import NiftiDataset
from monai.transforms import Compose, CastToType, SpatialPad, AddChannel, ScaleIntensity, Resize, RandRotate90, RandZoom, ToTensor
import os
import matplotlib.pyplot as plt
import glob

In [None]:
# Collect every COVID-positive CT volume in the BIMCV-COVID19 tree
# (layout: sub-S0*/ses-E0*/mod-rx/*_ct.nii.gz), sorted by path.
cov_pos_dir = '/home/marafath/scratch/bimcv/covid_pos/BIMCV-COVID19'
_pos_pattern = os.path.join(cov_pos_dir, "sub-S0*", "ses-E0*", "mod-rx", "*_ct.nii.gz")
train_images = sorted(glob.glob(_pos_pattern))

print(train_images[:20])  # total = 2291

In [None]:
# Collect every COVID-negative CT volume in the BIMCV-COVID19-Negative tree
# (same sub-*/ses-*/mod-rx layout as the positive set), sorted by path.
cov_neg_dir = '/home/marafath/scratch/bimcv/covid_neg/BIMCV-COVID19-Negative'
_neg_pattern = os.path.join(cov_neg_dir, "sub-S0*", "ses-E0*", "mod-rx", "*_ct.nii.gz")
train_images = sorted(glob.glob(_neg_pattern))

print(train_images[:20])
print(len(train_images))  # 1364

In [21]:
# Round every mask volume in the Zenodo mask folder to the nearest integer and
# write it back to the same file (the original file is overwritten).
# np.round is used instead of np.round_, which was deprecated in NumPy 1.25 and
# removed in NumPy 2.0; the results are identical.
data_dir = '/home/marafath/scratch/zenodo/mask'

for patient in os.listdir(data_dir):
    mask_path = os.path.join(data_dir, patient)
    img = nib.load(mask_path)
    data = np.round(img.get_fdata())
    rescaled_data = nib.Nifti1Image(data, img.affine, img.header)
    nib.save(rescaled_data, mask_path)

In [7]:
import copy
data_dir = '/home/marafath/scratch/zenodo/aug_data'
counter = 0


def _save_augmented_pair(image, mask, img_ref, msk_ref, index):
    """Write one augmented (image, label) pair into data_dir.

    Files are named 'image_<index>' / 'label_<index>' (bare roots, unchanged from
    before; presumably nibabel appends the '.nii' extension, so confirm the on-disk names).
    Affine and header come from the source image/mask, so every augmented copy
    keeps its source's geometry metadata.
    """
    nib.save(nib.Nifti1Image(image, img_ref.affine, img_ref.header),
             os.path.join(data_dir, 'image_' + str(index)))
    nib.save(nib.Nifti1Image(mask, msk_ref.affine, msk_ref.header),
             os.path.join(data_dir, 'label_' + str(index)))


# Build a 4x flip-augmented copy of every image/mask pair in aug_data: original,
# flipped along axis 0, flipped along axis 1, and flipped along both axes.
# Outputs are numbered consecutively by `counter` (4 per input volume).
for patient in sorted(os.listdir(os.path.join(data_dir, 'image'))):
    img = nib.load(os.path.join(data_dir, 'image', patient))
    data = img.get_fdata()

    # The mask has the same file name as the image. np.round replaces np.round_,
    # which was deprecated in NumPy 1.25 and removed in NumPy 2.0.
    msk = nib.load(os.path.join(data_dir, 'masks', patient))
    label = np.round(msk.get_fdata())

    if patient.startswith('radiopaedia'):
        # Linear map 0..255 -> -1250..250 (presumably converting 8-bit Radiopaedia
        # exports to an approximate HU range; confirm).
        data = np.divide(data, 255) * 1500 - 1250

    # Relabel: 2 -> 1 first, then 3 -> 2 (the order matters).
    label[label == 2] = 1
    label[label == 3] = 2

    # np.flip returns views. Nothing below mutates data/label, so copies are unnecessary.
    data_f0, label_f0 = np.flip(data, axis=0), np.flip(label, axis=0)
    variants = (
        (data, label),                                          # original
        (data_f0, label_f0),                                    # flipped along axis 0
        (np.flip(data, axis=1), np.flip(label, axis=1)),        # flipped along axis 1
        (np.flip(data_f0, axis=1), np.flip(label_f0, axis=1)),  # flipped along axes 0 and 1
    )
    for aug_data, aug_label in variants:
        _save_augmented_pair(aug_data, aug_label, img, msk, counter)
        counter += 1

In [None]:
import wget

# Fetch the BIMCV-COVID19-Negative archive from the B2DROP share into out_dir.
out_dir = '/home/marafath/scratch/bimcv/covid_neg'
url = 'https://b2drop.bsc.es/index.php/s/BIMCV-COVID19-Negative/download'
wget.download(url, out=out_dir)

PET Mean

In [3]:
# Global voxel-wise mean of the resized PET volumes over the first 25 CHGJ
# patients (sorted order). This is the `mu` used by the std/range cells.
data_dir = '/home/marafath/scratch/pet/to_compute_canada'
count = 0
im_sum = 0
voxel_count = 0

for patient in sorted(os.listdir(data_dir)):
    if not patient.startswith('CHGJ') or count >= 25:
        continue
    count += 1
    pt_path = os.path.join(data_dir, patient, patient + '_resized_pt.nii')
    pt_img = nib.load(pt_path).get_fdata()
    im_sum += np.sum(pt_img)
    voxel_count += pt_img.shape[0] * pt_img.shape[1] * pt_img.shape[2]

mean_val = im_sum / voxel_count
print(mean_val)  # mean = 0.27502204501836836
print(count)

0.27502204501836836
25


In [8]:
# Sanity check: report each CHGJ patient as inside ('Not yet') or beyond ('Hi')
# the first 25 used for the PET statistics, then print how many are beyond.
data_dir = '/home/marafath/scratch/pet/to_compute_canada'
count = 0
c2 = 0
im_sum = 0
voxel_count = 0

for patient in sorted(os.listdir(data_dir)):
    if patient[0:4] != 'CHGJ':
        continue
    if count < 25:
        print('Not yet')
    else:
        print('Hi')
        c2 += 1
    count += 1
print(c2)

Not yet
Not yet
Not yet
Not yet
Not yet
Not yet
Not yet
Not yet
Not yet
Not yet
Not yet
Not yet
Not yet
Not yet
Not yet
Not yet
Not yet
Not yet
Not yet
Not yet
Not yet
Not yet
Not yet
Not yet
Not yet
Hi
Hi
Hi
Hi
Hi
Hi
Hi
Hi
Hi
Hi
10


PET STD

In [4]:
# Global voxel-wise standard deviation of the same 25 resized PET volumes,
# centred on the previously computed mean (population std: divides by N).
data_dir = '/home/marafath/scratch/pet/to_compute_canada'
count = 0
im_sum = 0
voxel_count = 0
mu = 0.27502204501836836

for patient in sorted(os.listdir(data_dir)):
    if patient[0:4] != 'CHGJ' or count >= 25:
        continue
    count += 1
    rpt = nib.load(os.path.join(data_dir, patient, patient + '_resized_pt.nii'))
    pt_img = rpt.get_fdata()
    centered = pt_img - mu
    im_sum += np.sum(np.square(centered))
    voxel_count += pt_img.shape[0] * pt_img.shape[1] * pt_img.shape[2]

std_val = np.sqrt(im_sum / voxel_count)
print(std_val)  # std = 1.0551596380864339
print(count)

1.0551596380864339
25


Checking Max/Min range

In [6]:
# Range of the standardized PET intensities, (x - mu) / std, over the same 25
# patients, showing how far the outliers extend.
data_dir = '/home/marafath/scratch/pet/to_compute_canada'
count = 0
im_sum = 0
voxel_count = 0
mu = 0.27502204501836836
std = 1.0551596380864339

Min = []
Max = []

for patient in sorted(os.listdir(data_dir)):
    if patient[0:4] != 'CHGJ' or count >= 25:
        continue
    count += 1
    rpt = nib.load(os.path.join(data_dir, patient, patient + '_resized_pt.nii'))
    pt_img = rpt.get_fdata()
    im_ = np.divide(pt_img - mu, std)
    Min.append(np.min(im_))
    Max.append(np.max(im_))

print(np.min(Min))  # = -0.26064496318029157
print(np.max(Max))  # = 61.16492458933134

-0.26064496318029157
61.16492458933134


In [None]:
# Mean resized-PET intensity over the patients that have a file in label_dir
# (patient id = first 7 characters of each file name).
data_dir = '/home/marafath/scratch/pet/to_compute_canada'
label_dir = '/home/marafath/scratch/pet/to_compute_canada/noisy_500faces_test2'

train_labels = []
train_images = []
im_sum = 0
voxel_count = 0

for noisy_file in os.listdir(label_dir):
    pat_id = noisy_file[:7]
    pt = nib.load(os.path.join(data_dir, pat_id, pat_id + '_resized_pt.nii'))
    pt_img = pt.get_fdata()
    im_sum += np.sum(pt_img)
    voxel_count += pt_img.shape[0] * pt_img.shape[1] * pt_img.shape[2]

mean_val = im_sum / voxel_count
# NOTE(review): 0.27502204501836836 was recorded here, which equals the 25-patient
# mean above; confirm it was actually recomputed for this patient set.
print(mean_val)

Loading KiTS data

In [None]:
# Convert the first 11 KiTS cases (os.listdir order) from NIfTI to .npy inside
# each case folder. image.npy holds the CT volume; mask.npy holds the segmentation
# with label 2 set to 0.
data_dir = '/home/marafath/scratch/kits/training'

c = 0
for patient in os.listdir(data_dir):
    c += 1
    case_dir = os.path.join(data_dir, patient)

    ct_img = nib.load(os.path.join(case_dir, 'imaging.nii.gz')).get_fdata()
    np.save(os.path.join(case_dir, 'image'), ct_img)

    gt = nib.load(os.path.join(case_dir, 'segmentation.nii.gz')).get_fdata()
    gt[gt == 2] = 0  # drop label 2 (presumably tumour in KiTS; confirm)
    np.save(os.path.join(case_dir, 'mask'), gt)

    print(case_dir)

    if c > 10:
        break

Loading and visualizing PET/CT head-and-neck data

In [None]:
# Visual check of axial slice `sl` of the CT, PET, GTV mask and resized PET for
# every CHGJ patient, printing each volume's shape and the resized-PET range.
data_dir = '/home/marafath/scratch/pet/to_compute_canada'
sl = 80

for patient in os.listdir(data_dir):
    if patient[0:4] != 'CHGJ':
        continue
    case_dir = os.path.join(data_dir, patient)
    ct_img = nib.load(os.path.join(case_dir, patient + '_ct.nii.gz')).get_fdata()
    pet_img = nib.load(os.path.join(case_dir, patient + '_pt.nii.gz')).get_fdata()
    gtv_img = nib.load(os.path.join(case_dir, patient + '_ct_gtvt.nii.gz')).get_fdata()
    rpt_img = nib.load(os.path.join(case_dir, patient + '_resized_pt.nii')).get_fdata()

    # The fourth panel is the resized PET. It was previously mislabeled 'GTV'.
    panels = (('CT', ct_img), ('PET', pet_img), ('GTV', gtv_img), ('Resized PET', rpt_img))

    plt.figure('check', (18, 6))
    for pos, (title, vol) in enumerate(panels, start=1):
        plt.subplot(1, 4, pos)
        plt.title(title)
        plt.imshow(vol[:, :, sl], cmap='gray')
    for _, vol in panels:
        print(vol.shape)
    print(np.max(rpt_img))
    print(np.min(rpt_img))
    plt.show()

In [None]:
# Side-by-side comparison of the CHGJ007 GTV mask ('original') and its
# '_noisy_500faces' counterpart on every other slice from 45 to 59.
nm = nib.load('/home/marafath/scratch/pet/to_compute_canada/noisy_500faces_test2/CHGJ007_noisy_500faces.nii').get_fdata()
m = nib.load('/home/marafath/scratch/pet/to_compute_canada/CHGJ007/CHGJ007_ct_gtvt.nii.gz').get_fdata()

s = nm.shape[-1]

for i in range(45, 60, 2):
    plt.figure('check', (18, 6))
    for pos, (title, vol) in enumerate((('original', m), ('noisy', nm)), start=1):
        plt.subplot(1, 2, pos)
        plt.title(title)
        plt.imshow(vol[:, :, i], cmap='gray')
    plt.show()

In [None]:
import csv

# Read every row of the HECKTOR bounding-box CSV (no header handling) into `rows`.
filename = "/home/marafath/scratch/pet_ct/hecktor_train/bbox.csv"

# The csv module documentation requires files passed to csv.reader to be opened
# with newline='', so quoted fields that contain newlines are parsed correctly.
with open(filename, 'r', newline='') as csvfile:
    csvreader = csv.reader(csvfile)
    rows = list(csvreader)
    # line_num is the number of lines read from the source so far.
    print("Total no. of rows: %d" % (csvreader.line_num))

print(rows)

#### Iran data: creating false masks and subtracting the mean

In [None]:
# Compare image.nii.gz with lung.nii.gz (slice 50) for the first 20 Iran series,
# printing each volume's intensity range.
data_dir = '/home/marafath/scratch/iran_organized_data2'

counter = 0
flag = 0
for patient in os.listdir(data_dir):
    label = int(patient[-1])  # last character of the folder name is the class label
    for series in os.listdir(os.path.join(data_dir, patient)):
        # Counted once per series. Previously it was incremented twice, so the loop
        # stopped after ~11 series instead of the presumably intended 20.
        counter += 1

        img = nib.load(os.path.join(data_dir, patient, series, 'image.nii.gz'))
        img = img.get_fdata()

        img_ = nib.load(os.path.join(data_dir, patient, series, 'lung.nii.gz'))
        img_ = img_.get_fdata()

        '''
        img_ = img - 100 
        
        false_mask = np.zeros((img.shape), dtype=np.int16)
        
        img_ = nib.Nifti1Image(img_, np.eye(4))
        nib.save(img_, os.path.join(data_dir,patient,series,'image_ms.nii.gz')) 
        
        false_mask = nib.Nifti1Image(false_mask, np.eye(4))
        nib.save(false_mask, os.path.join(data_dir,patient,series,'segmentation.nii.gz'))  
        '''

        plt.figure('check', (18, 6))
        plt.subplot(1, 2, 1)
        plt.title('image ')
        plt.imshow(img[:, :, 50], cmap='gray')
        print(np.max(img), np.min(img))
        plt.subplot(1, 2, 2)
        plt.title('lung ')
        plt.imshow(img_[:, :, 50], cmap='gray')
        print(np.max(img_), np.min(img_))
        plt.show()

        if counter >= 20:
            flag = 1
            break
    if flag == 1:
        break

#### Reconfiguring the EU data masks and checking the data

In [None]:
# Show each EU image next to image_masked.nii.gz on the SAME axial slice, with
# value ranges. The disabled remapping below expects label values up to 6, so
# image_masked.nii.gz is presumably a label map; confirm.
data_dir = '/home/marafath/scratch/eu_data'
# The image was previously drawn at slice 30 and the mask at slice 50, so the two
# panels did not correspond. Both now use one index.
sl = 50

for case in os.listdir(data_dir):
    img = nib.load(os.path.join(data_dir, case, 'image.nii.gz')).get_fdata()
    seg = nib.load(os.path.join(data_dir, case, 'image_masked.nii.gz')).get_fdata()
    #if np.max(seg) == 6:
    #    label = 1
    #else:
    #    label = 0

    '''
    
    seg[seg == 2] = 1
    seg[seg == 3] = 1
    seg[seg == 4] = 1
    seg[seg == 5] = 1
    seg[seg == 6] = 2
  
    
    seg = nib.Nifti1Image(seg, np.eye(4))
    nib.save(seg, os.path.join(data_dir,case,'segmentation_l&i.nii.gz'))  
    '''

    plt.figure('check', (18, 6))
    plt.subplot(1, 2, 1)
    plt.title('image ')
    plt.imshow(img[:, :, sl], cmap='gray')
    print(np.max(img), np.min(img))
    plt.subplot(1, 2, 2)
    plt.title('mask ')
    plt.imshow(seg[:, :, sl], cmap='gray')
    print(np.max(seg), np.min(seg))
    plt.show()

#### Trying a median filter to remove sparsity in the infection masks

In [None]:
from scipy import ndimage

# Smooth every EU label map with a size-5 (5x5x5) median filter to remove isolated
# infection voxels, overwriting segmentation_l&i.nii.gz in place.
# NOTE(review): the cell overwrites its own input, so each re-run filters the
# masks again. Run it once only.
data_dir = '/home/marafath/scratch/eu_data'

for case in os.listdir(data_dir):
    seg_path = os.path.join(data_dir, case, 'segmentation_l&i.nii.gz')
    seg_nii = nib.load(seg_path)
    seg = seg_nii.get_fdata()
    seg_ = ndimage.median_filter(seg, size=5)

    # Reuse the file's own affine instead of np.eye(4), so spatial metadata already
    # stored in the mask is not discarded by the rewrite.
    nib.save(nib.Nifti1Image(seg_, seg_nii.affine), seg_path)

#### Checking Radiopaedia data

In [None]:
# Show slice 45 of every 3d_seg_ct image and its label map, with value ranges.
data_dir = '/home/marafath/scratch/3d_seg_ct'

for case in os.listdir(data_dir):
    case_dir = os.path.join(data_dir, case)
    img = nib.load(os.path.join(case_dir, 'image.nii.gz')).get_fdata()
    seg = nib.load(os.path.join(case_dir, 'label.nii.gz')).get_fdata()

    plt.figure('check', (18, 6))
    plt.subplot(1, 2, 1)
    plt.title('original image')
    plt.imshow(img[:, :, 45], cmap='gray')
    print(np.max(img), np.min(img))
    plt.subplot(1, 2, 2)
    plt.title('Mask')
    plt.imshow(seg[:, :, 45])
    print(np.max(seg), np.min(seg))
    plt.show()

#### Loading data from Radiopaedia and EU sources and testing a random mix

In [None]:
import os

# Gather (image, label) path pairs from both sources: the 3d_seg_ct cases
# (label.nii.gz) followed by the EU cases (segmentation_l&i.nii.gz).
images = []
labels = []

for data_dir, label_name in (('/home/marafath/scratch/3d_seg_ct', 'label.nii.gz'),
                             ('/home/marafath/scratch/eu_data', 'segmentation_l&i.nii.gz')):
    for case in os.listdir(data_dir):
        images.append(os.path.join(data_dir, case, 'image.nii.gz'))
        labels.append(os.path.join(data_dir, case, label_name))

In [None]:
# f-fold cross-validation split over a fixed random permutation of `images`.
# Fold k validates on permuted positions [fold_size*k, fold_size*(k+1)). The last
# fold validates on everything from fold_size*(f-1) onward. With f=4 and
# fold_size=44 these are exactly the previous hard-coded ranges.
# NOTE(review): each iteration overwrites the lists, so only the last fold survives
# the loop. Training must happen inside the loop body to use every fold.
np.random.seed(31)
idx = np.random.permutation(range(len(images)))

f = 4
fold_size = 44  # validation samples per fold (the last fold absorbs the remainder)

for fold in range(0, f):
    val_start = fold * fold_size
    val_stop = len(idx) if fold == f - 1 else val_start + fold_size
    train_images = []
    train_labels = []
    val_images = []
    val_labels = []
    for i in range(0, len(idx)):
        if val_start <= i < val_stop:
            val_images.append(images[idx[i]])
            val_labels.append(labels[idx[i]])
        else:
            train_images.append(images[idx[i]])
            train_labels.append(labels[idx[i]])

#### Creating segmentations for the first 14 series

In [None]:
# Collect image/segmentation paths and series folders for the first 14 Iran
# series, in os.listdir order.
data_dir = '/home/marafath/scratch/iran_organized_data2'

images = []
labels = []
im_dir = []

i = 0
flag = 0
for patient in os.listdir(data_dir):
    for series in os.listdir(os.path.join(data_dir, patient)):
        series_dir = os.path.join(data_dir, patient, series)
        images.append(os.path.join(series_dir, 'image.nii.gz'))
        labels.append(os.path.join(series_dir, 'segmentation.nii.gz'))
        im_dir.append(series_dir)

        i += 1
        if i == 14:
            flag = 1
            break
    if flag == 1:
        break

In [None]:
# Image pipeline: intensity scaling, channel axis, dtype cast (CastToType default),
# spatial padding to 96^3, tensor conversion. The segmentation pipeline is the
# same without the intensity scaling.
pad_size = (96, 96, 96)
train_imtrans = Compose([
    ScaleIntensity(),
    AddChannel(),
    CastToType(),
    SpatialPad(pad_size, mode='constant'),
    ToTensor(),
])
train_segtrans = Compose([
    AddChannel(),
    CastToType(),
    SpatialPad(pad_size, mode='constant'),
    ToTensor(),
])

In [None]:
# Pair each image with its segmentation. Batch size 1 with a fixed order (no
# shuffle), so loader position i corresponds to images[i].
train_ds = NiftiDataset(images, labels, transform=train_imtrans, seg_transform=train_segtrans)
train_loader = DataLoader(
    train_ds,
    batch_size=1,
    shuffle=False,
    num_workers=4,
    pin_memory=torch.cuda.is_available(),
)

In [None]:
# 3D residual U-Net: 1 input channel, 3 output classes (presumably labels 0/1/2),
# five resolution levels.
# Fall back to the CPU when no GPU is visible instead of failing on 'cuda:0'.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
model = monai.networks.nets.UNet(
    dimensions=3,
    in_channels=1,
    out_channels=3,
    channels=(16, 32, 64, 128, 256),
    strides=(2, 2, 2, 2),
    num_res_units=2,
).to(device)

In [None]:
from monai.inferers import sliding_window_inference
from torch.utils.data import Subset

# Pseudo-label every series from position `start_idx` onward with the trained
# U-Net, preview slice `sl`, and write image/segmentation NIfTI files to im_dir[i].
# NOTE(review): this OVERWRITES each series' image.nii.gz with the network input
# (already transformed by train_imtrans) and segmentation.nii.gz with the argmax
# prediction. The original image.nii.gz is not kept, so confirm before re-running.
# NOTE(review): loader position i must line up with im_dir[i], so `images` and
# `im_dir` must come from the same listing.
start_idx = 1246           # series before this index are skipped
roi_size = (160, 160, 96)  # sliding-window size
sw_batch_size = 4          # windows evaluated per forward pass
sl = 45                    # slice shown in the preview

# map_location places the checkpoint tensors on `device`, so a GPU-saved checkpoint
# also loads when the model is elsewhere.
model.load_state_dict(torch.load('/home/marafath/scratch/saved_models/rad_ir_ssl_it_50.pth',
                                 map_location=device))
model.eval()

# Iterate only the pending series. Previously the loop walked train_loader from 0
# and discarded the first `start_idx` items, but the DataLoader still read and
# transformed every one of those volumes.
pending_loader = DataLoader(
    Subset(train_ds, range(start_idx, len(train_ds))),
    batch_size=1,
    shuffle=False,
    num_workers=4,
    pin_memory=torch.cuda.is_available(),
)

with torch.no_grad():
    for i, val_data in enumerate(pending_loader, start=start_idx):
        val_images, val_labels = val_data[0].to(device), val_data[1].to(device)
        val_outputs = sliding_window_inference(val_images, roi_size, sw_batch_size, model)
        pseudo_mask = np.squeeze(torch.argmax(val_outputs, dim=1).detach().cpu().numpy())

        im = np.squeeze(val_images.cpu().detach().numpy())
        seg = np.squeeze(val_labels.cpu().detach().numpy())

        plt.subplot(1, 3, 1)
        plt.title('image ' + str(i))
        plt.imshow(im[:, :, sl], cmap='gray')
        print(im.shape)
        plt.subplot(1, 3, 2)
        plt.title('label ' + str(i))
        plt.imshow(seg[:, :, sl])
        print(seg.shape)
        plt.subplot(1, 3, 3)
        plt.title('output ' + str(i))
        plt.imshow(pseudo_mask[:, :, sl])
        print(im_dir[i])
        plt.show()

        nib.save(nib.Nifti1Image(im, np.eye(4)), os.path.join(im_dir[i], 'image.nii.gz'))
        nib.save(nib.Nifti1Image(pseudo_mask, np.eye(4)), os.path.join(im_dir[i], 'segmentation.nii.gz'))

#### Observing the segmentation

In [None]:
# Collect masked_image / segmentation paths and series folders for every Iran
# series, in os.listdir order.
data_dir = '/home/marafath/scratch/iran_organized_data2'

images = []
labels = []
im_dir = []

for patient in os.listdir(data_dir):
    patient_dir = os.path.join(data_dir, patient)
    for series in os.listdir(patient_dir):
        series_dir = os.path.join(data_dir, patient, series)
        im_dir.append(series_dir)
        images.append(os.path.join(series_dir, 'masked_image.nii.gz'))
        labels.append(os.path.join(series_dir, 'segmentation.nii.gz'))

In [None]:
# Same pipelines as before: the image gets intensity scaling, channel axis, cast,
# padding to 96^3 and tensor conversion. The segmentation skips the scaling.
pad_size = (96, 96, 96)
train_imtrans = Compose([
    ScaleIntensity(),
    AddChannel(),
    CastToType(),
    SpatialPad(pad_size, mode='constant'),
    ToTensor(),
])
train_segtrans = Compose([
    AddChannel(),
    CastToType(),
    SpatialPad(pad_size, mode='constant'),
    ToTensor(),
])

In [None]:
# Unshuffled, batch-size-1 loader over the masked images, so position i matches im_dir[i].
train_ds = NiftiDataset(images, labels, transform=train_imtrans, seg_transform=train_segtrans)
train_loader = DataLoader(
    train_ds,
    batch_size=1,
    shuffle=False,
    num_workers=4,
    pin_memory=torch.cuda.is_available(),
)

In [None]:
# Preview the first 21 (image, label) pairs at slices 25 and 50, printing each
# image's minimum and each label map's maximum.
i = 0
for val_data in train_loader:
    im = np.squeeze(val_data[0].cpu().detach().numpy())
    seg = np.squeeze(val_data[1].cpu().detach().numpy())

    plt.figure('check', (12, 4))
    for col, sl in enumerate((25, 50)):
        plt.subplot(1, 4, 2 * col + 1)
        plt.title('image ' + str(i))
        plt.imshow(im[:, :, sl], cmap='gray')
        plt.subplot(1, 4, 2 * col + 2)
        plt.title('label ' + str(i))
        plt.imshow(seg[:, :, sl])
    print(np.min(im))
    print(np.max(seg))
    plt.show()

    i += 1
    if i > 20:
        break

In [None]:
# Show slices 30 and 50 of the first 101 images. Each title carries a flag that is
# 1 when the label map's maximum is 2 (presumably infection present; confirm).
i = 0
for val_data in train_loader:
    im = np.squeeze(val_data[0].cpu().detach().numpy())
    seg = np.squeeze(val_data[1].cpu().detach().numpy())

    label = 1 if np.max(seg) == 2 else 0

    plt.figure('check', (18, 6))
    for pos, sl in enumerate((30, 50), start=1):
        plt.subplot(1, 2, pos)
        plt.title('image ' + str(label))
        plt.imshow(im[:, :, sl], cmap='gray')
    plt.show()

    # `i` was previously never incremented, so this early stop never fired and
    # the cell walked the whole dataset.
    i += 1
    if i > 100:
        break

In [None]:
import shutil

# Delete the series folders flagged as bad during visual inspection.
# NOTE(review): these are positions in `im_dir`, which is built from os.listdir
# order. They are only meaningful for the im_dir list inspected in this session.
bad_series = (313, 314, 316, 317, 318,
              594, 595, 596, 597, 598, 599, 600, 601,
              746, 747, 789,
              1059, 1060, 1061, 1154)

# Resolve every path before deleting anything, and skip folders that are already
# gone, so a re-run does not crash on the first missing directory.
for path in [im_dir[k] for k in bad_series]:
    if os.path.isdir(path):
        shutil.rmtree(path)

In [None]:
import copy

# For each (image, label map): build a binary infection mask (values >= 2 -> 1,
# smaller values -> 0), multiply it into the image, and plot slice 50 of the image,
# label map, masked image and infection mask.
i = 0
for val_data in train_loader:
    im = np.squeeze(val_data[0].cpu().detach().numpy())
    seg = np.squeeze(val_data[1].cpu().detach().numpy())

    # (disabled) lung mask: every labelled voxel -> 1
    # lung = copy.deepcopy(seg)
    # lung[lung > 0] = 1
    infection = copy.deepcopy(seg)
    infection[infection < 2] = 0
    infection[infection > 0] = 1

    masked_infection = np.multiply(im, infection)

    panels = (
        ('image ', im, 'gray'),
        ('seg ', seg, None),
        ('mas_inf', masked_infection, None),
        ('infection ', infection, None),
    )
    plt.figure('check', (12, 4))
    for pos, (title, vol, cmap) in enumerate(panels, start=1):
        plt.subplot(1, 4, pos)
        plt.title(title + str(i))
        plt.imshow(vol[:, :, 50], cmap=cmap)
    plt.show()

    # (disabled) saving of the derived volumes next to each series:
    # masked_infection = nib.Nifti1Image(masked_infection, np.eye(4))
    # nib.save(masked_infection, os.path.join(im_dir[i],'masked_infection.nii.gz'))
    # lung = nib.Nifti1Image(lung, np.eye(4))
    # nib.save(lung, os.path.join(im_dir[i],'lung.nii.gz'))
    # infection = nib.Nifti1Image(infection, np.eye(4))
    # nib.save(infection, os.path.join(im_dir[i],'infection.nii.gz'))

    i += 1

#### Counting data

In [None]:
# Count patients and series (volumes) per class. The last character of each
# patient folder name is the class (1 = COVID, anything else = non-COVID).
data_dir = '/home/marafath/scratch/iran_organized_data2'

covid_pat = 0
non_covid_pat = 0

covid_series = 0
non_covid_series = 0

for patient in os.listdir(data_dir):
    is_covid = int(patient[-1]) == 1
    n_series = len(os.listdir(os.path.join(data_dir, patient)))
    if is_covid:
        covid_pat += 1
        covid_series += n_series
    else:
        non_covid_pat += 1
        non_covid_series += n_series

print('COVID patients {}'.format(covid_pat))
print('nonCOVID patient {}'.format(non_covid_pat))

print('COVID volumes {}'.format(covid_series))
print('nonCOVID volumes {}'.format(non_covid_series))

In [None]:
# Build per-class lists of masked_infection.nii.gz paths and labels. Class-0
# patients are skipped once 237 non-COVID patients have been taken.
# NOTE(review): os.listdir order is filesystem-dependent, so which non-COVID
# patients are kept (and the index-based split that follows) may differ between
# machines.
data_dir = '/home/marafath/scratch/iran_organized_data2'

covid_pat = 0
non_covid_pat = 0

images_p = []
labels_p = []
images_n = []
labels_n = []

for patient in os.listdir(data_dir):
    cls = int(patient[-1])
    if cls == 0 and non_covid_pat > 236:
        continue

    patient_dir = os.path.join(data_dir, patient)
    paths = [os.path.join(data_dir, patient, series, 'masked_infection.nii.gz')
             for series in os.listdir(patient_dir)]
    if cls == 1:
        covid_pat += 1
        images_p.extend(paths)
        labels_p.extend([1] * len(paths))
    else:
        non_covid_pat += 1
        images_n.extend(paths)
        labels_n.extend([0] * len(paths))

In [None]:
# Sanity check of the per-class lists: path count, label count and label sum
# (COVID first, then non-COVID, separated by a blank block).
for k, (paths, labs) in enumerate(((images_p, labels_p), (images_n, labels_n))):
    if k:
        print('\n')
    print(len(paths))
    print(len(labs))
    print(np.sum(labs))

In [None]:
# Index-based train/val split per class: the first 407 COVID and first 405
# non-COVID volumes go to training and the rest to validation (COVID entries
# come first in each list).
n_train_p = 407
n_train_n = 405

train_images = images_p[:n_train_p] + images_n[:n_train_n]
train_labels = labels_p[:n_train_p] + labels_n[:n_train_n]

val_images = images_p[n_train_p:] + images_n[n_train_n:]
val_labels = labels_p[n_train_p:] + labels_n[n_train_n:]

print(len(train_images))
print(len(val_images))

In [None]:
print(str(train_labels))  # full label list of the training split

#### Copying masked data to new folder

In [None]:
import os
import numpy as np
import shutil

# Mirror the Iran patient/series layout under m_data, copying only each series'
# masked_infection.nii.gz and lung.nii.gz.
data_dir = '/home/marafath/scratch/iran_organized_data2'
out_root = '/home/marafath/scratch/m_data'

for patient in os.listdir(data_dir):
    # makedirs(exist_ok=True) replaces os.mkdir, which raised FileExistsError
    # when the cell was re-run.
    os.makedirs(os.path.join(out_root, patient), exist_ok=True)
    for series in os.listdir(os.path.join(data_dir, patient)):
        dst_dir = os.path.join(out_root, patient, series)
        os.makedirs(dst_dir, exist_ok=True)
        for fname in ('masked_infection.nii.gz', 'lung.nii.gz'):
            shutil.copyfile(os.path.join(data_dir, patient, series, fname),
                            os.path.join(dst_dir, fname))

#### Producing -/+ preserved lung masked images for Iran Data

In [None]:
import copy

# Show each series' masked_image.nii.gz at axial slices 30 and 50, printing its
# intensity range once per panel.
data_dir = '/home/marafath/scratch/iran_organized_data2'

counter = 0
flag = 0
for patient in os.listdir(data_dir):
    label = int(patient[-1])
    for series in os.listdir(os.path.join(data_dir, patient)):
        counter += 1

        img_ = nib.load(os.path.join(data_dir, patient, series, 'masked_image.nii.gz')).get_fdata()

        # (disabled) regeneration of masked_image.nii.gz:
        # img = nib.load(os.path.join(data_dir,patient,series,'image.nii.gz'))
        # img = img.get_fdata()
        # img_to_fix = copy.deepcopy(img)
        # max_val = np.max(img_) - np.min(img_)
        # img = np.multiply(img, max_val)
        # img = img + np.min(img_)
        # seg = nib.load(os.path.join(data_dir,patient,series,'segmentation.nii.gz'))
        # seg = seg.get_fdata()
        # seg[seg == 2] = 1
        # masked_image = np.multiply(img,seg)
        # masked_image = nib.Nifti1Image(masked_image, np.eye(4))
        # nib.save(masked_image, os.path.join(data_dir,patient,series,'masked_image.nii.gz'))

        plt.figure('check', (18, 6))
        for pos, (title, sl) in enumerate((('image ', 30), ('masked ', 50)), start=1):
            plt.subplot(1, 2, pos)
            plt.title(title)
            plt.imshow(img_[:, :, sl], cmap='gray')
            print(np.max(img_), np.min(img_))
        plt.show()
        

#### Largest connected component, non-zero bounding box, and 3D image resizing

In [None]:
# Aliased so later cells that assign `label = int(...)` cannot break a re-run of
# this cell (the plain name `label` used to be clobbered).
from skimage.measure import label as cc_label
# `import scipy` alone does not guarantee that scipy.signal is loaded.
from scipy import signal

data_dir = '/home/marafath/scratch/iran_organized_data2'
target_shape = (128, 128, 128)


def _largest_component(mask):
    """Return a boolean mask of the largest connected component of a 2-D mask."""
    components = cc_label(mask)
    assert components.max() != 0  # assume at least 1 CC
    return components == np.argmax(np.bincount(components.flat)[1:]) + 1


def _nonzero_extent(flags):
    """Return (first, last) indices at which the 1-D boolean array `flags` is True."""
    first, last = np.where(flags)[0][[0, -1]]
    return first, last


def _nearest_resize(volume, shape):
    """Nearest-neighbour resize that picks source index int(k * old/new) per axis.

    Vectorized with np.ix_. This gives the same result as the previous
    per-voxel Python loop over 128**3 positions.
    """
    axes = [(np.arange(new) * (old / new)).astype(int)
            for old, new in zip(volume.shape, shape)]
    return volume[np.ix_(*axes)]


# For each series: crop masked_image.nii.gz to the bounding box of the lung mask
# (rows/cols from the axis-2 projection, slice range from the axis-0 projection),
# resize the crop to 128^3 and save it as cropped_and_resized_image.nii.gz.
for patient in os.listdir(data_dir):
    for series in os.listdir(os.path.join(data_dir, patient)):
        series_dir = os.path.join(data_dir, patient, series)
        img_ = nib.load(os.path.join(series_dir, 'masked_image.nii.gz')).get_fdata()
        lung = nib.load(os.path.join(series_dir, 'lung.nii.gz')).get_fdata()

        # Binary footprint of the lung mask projected along axis 2.
        img_ax = np.sum(lung, axis=2)
        img_ax[img_ax > 0] = 1
        # Bridge drawn at column 100, rows 50-174. Presumably it joins the two lungs
        # into one component so the largest-CC step keeps both; confirm for this data.
        img_ax[50:175, 100] = 1

        largestCC = _largest_component(img_ax)
        # Median filter with a 1x5 kernel along axis 1 to trim thin protrusions.
        largestCC = signal.medfilt2d(np.array(largestCC, float), [1, 5])

        ymin, ymax = _nonzero_extent(np.any(largestCC, axis=1))  # axis-0 extent
        xmin, xmax = _nonzero_extent(np.any(largestCC, axis=0))  # axis-1 extent

        # Binary footprint projected along axis 0; its columns give the axis-2 extent.
        img_sg = np.sum(lung, axis=0)
        img_sg[img_sg > 0] = 1
        zmin, zmax = _nonzero_extent(np.any(_largest_component(img_sg), axis=0))

        data = img_[ymin:ymax + 1, xmin:xmax + 1, zmin:zmax + 1]
        new_data = _nearest_resize(data, target_shape)

        nib.save(nib.Nifti1Image(new_data, np.eye(4)),
                 os.path.join(series_dir, 'cropped_and_resized_image.nii.gz'))

In [None]:
import SimpleITK as sitk

#### Inspecting the cropped and resized data

In [None]:
import cv2
import matplotlib
import scipy

# Show a 2-D projection of each cropped/resized volume: move axis 2 first, sum
# over original axis 1, min-max normalize to [0, 1], and replicate to 3 channels.
data_dir = '/home/marafath/scratch/iran_organized_data2'

i = 0
f = 0
for patient in os.listdir(data_dir):
    label = int(patient[-1])
    for series in os.listdir(os.path.join(data_dir, patient)):
        img = nib.load(os.path.join(data_dir, patient, series,
                                    'cropped_and_resized_image.nii.gz')).get_fdata()

        img = np.transpose(img, (2, 0, 1))
        #img = np.transpose(img, (2, 1, 0))
        img = np.sum(img, axis=2)
        img = img - np.min(img)
        peak = np.max(img)
        # A constant projection would otherwise compute 0/0 and produce NaNs.
        if peak > 0:
            img = img / peak

        # Three identical channels. The shape follows the data instead of the
        # previously hard-coded 128x128 buffer.
        img3 = np.stack([img, img, img], axis=-1)

        plt.figure('check', (18, 6))
        plt.title('image ' + str(label))
        plt.imshow(img3, cmap='gray')
        print(np.min(img3), np.max(img3))
        plt.show()

        #filename = os.path.join(data_dir,patient,series,'chest_projection.png')
        #matplotlib.image.imsave(filename, img3)
        #cv2.imwrite(filename, img3) 

#### Inspecting PNG Images

In [None]:
import cv2
data_dir = '/home/marafath/scratch/iran_organized_data2'

for patient in os.listdir(data_dir):
    # Last character of the patient folder name is the class label (0/1).
    class_label = int(patient[-1])
    for series in os.listdir(os.path.join(data_dir, patient)):
        png_path = os.path.join(data_dir, patient, series, 'chest_projection.png')
        img = cv2.imread(png_path)

        # cv2.imread returns None (it does not raise) when the file is missing or
        # unreadable, e.g. after the "Move projection data" cell has moved it away.
        if img is None:
            print('skipping missing/unreadable file:', png_path)
            continue

        # OpenCV loads channels in BGR order; matplotlib expects RGB.
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        plt.title('image ' + str(class_label))
        plt.imshow(img)
        plt.show()

#### Move projection data to new folders for running 2D classification

In [None]:
import shutil
data_dir = '/home/marafath/scratch/iran_organized_data2'
out_dir = '/home/marafath/scratch/projection_2d'

# Non-COVID (label 0) patients beyond this count are skipped.
MAX_NON_COVID_PATIENTS = 237
# Number of series per class moved to train/; the remaining series go to val/.
N_TRAIN_COVID = 407
N_TRAIN_NONCOVID = 405

covid_pat = 0
non_covid_pat = 0

images_p = []
labels_p = []
images_n = []
labels_n = []

# Sorted so the patient cap and the train/val split are reproducible
# (os.listdir returns entries in arbitrary order).
for patient in sorted(os.listdir(data_dir)):
    class_label = int(patient[-1])
    if class_label == 0 and non_covid_pat >= MAX_NON_COVID_PATIENTS:
        continue

    series_paths = [os.path.join(data_dir, patient, series, 'chest_projection.png')
                    for series in sorted(os.listdir(os.path.join(data_dir, patient)))]

    if class_label == 1:
        covid_pat += 1
        images_p.extend(series_paths)
        labels_p.extend([1] * len(series_paths))
    else:
        non_covid_pat += 1
        images_n.extend(series_paths)
        labels_n.extend([0] * len(series_paths))


def _move_split(paths, n_train, class_dir, root=out_dir):
    """Move paths[:n_train] to <root>/train/<class_dir>/ and the rest to <root>/val/<class_dir>/.

    Each file is renamed to '<index>.png', where index is its position in `paths`.
    NOTE(review): the split is by series position, not by patient, so the series
    of the patient sitting at the n_train boundary can end up in both train and val.
    """
    for split in ('train', 'val'):
        # shutil.move fails if the destination directory does not exist.
        os.makedirs(os.path.join(root, split, class_dir), exist_ok=True)
    for i, src in enumerate(paths):
        split = 'train' if i < n_train else 'val'
        shutil.move(src, os.path.join(root, split, class_dir, str(i) + '.png'))


_move_split(images_p, N_TRAIN_COVID, 'covid')
_move_split(images_n, N_TRAIN_NONCOVID, 'noncovid')

In [None]:
import sys
import logging
import numpy as np
import nibabel as nib
import torch
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

import monai
from monai.data import NiftiDataset
from monai.transforms import Compose, CastToType, SpatialPad, AddChannel, ScaleIntensity, Resize, RandRotate90, RandZoom, ToTensor
import os
import matplotlib.pyplot as plt

#### Generating cropped and resized infections

In [None]:
from skimage.measure import label 
import scipy
import scipy.signal  # explicit: `import scipy` alone does not load the signal submodule on older SciPy
import copy
import itertools

data_dir = '/home/marafath/scratch/iran_organized_data2'

# Output grid of the nearest-neighbour resize.
NEW_SIZE = (128, 128, 128)
# Set to True to display slices 50 and 90 (along axis 2) of every resized volume.
PREVIEW = False


def largest_component(mask):
    """Return a boolean mask of the largest connected component of a binary 2-D mask."""
    labels = label(mask)
    assert labels.max() != 0  # assume at least 1 CC
    # bincount[0] is the background; +1 restores the component id after slicing it off.
    return labels == np.argmax(np.bincount(labels.flat)[1:]) + 1


def nearest_resize(data, new_shape):
    """Nearest-neighbour resample a 3-D array to `new_shape`.

    Output voxel (x, y, z) takes data[int(x*dx), int(y*dy), int(z*dz)] with
    d = old_size / new_size. This gives the same result as a per-voxel loop,
    but uses fancy indexing instead of ~2M Python iterations.
    """
    idx = [(np.arange(n) * (old / n)).astype(int) for old, n in zip(data.shape, new_shape)]
    return data[np.ix_(*idx)]


for patient in sorted(os.listdir(data_dir)):
    for series in sorted(os.listdir(os.path.join(data_dir, patient))):
        series_dir = os.path.join(data_dir, patient, series)

        infection = nib.load(os.path.join(series_dir, 'masked_infection.nii.gz')).get_fdata()
        lung = nib.load(os.path.join(series_dir, 'lung.nii.gz')).get_fdata()

        # Footprint over axis 2 (np.sum returns a new array, so `lung` is untouched).
        lung_ax = np.sum(lung, axis=2)
        lung_ax[lung_ax > 0] = 1
        # 1-pixel-wide vertical bridge (rows 50-174, column 100), presumably to join both
        # lungs into one component; the 1x5 median filter below, which runs along axis 1,
        # suppresses such 1-pixel-wide column features.
        lung_ax[50:175, 100] = 1

        largest_ax = largest_component(lung_ax)
        largest_ax = scipy.signal.medfilt2d(np.array(largest_ax, float), [1, 5])

        # Bounding box of the footprint: rows -> axis 0 extent, cols -> axis 1 extent.
        ymin, ymax = np.where(np.any(largest_ax, axis=1))[0][[0, -1]]
        xmin, xmax = np.where(np.any(largest_ax, axis=0))[0][[0, -1]]

        # Footprint over axis 0 -> (axis 1, axis 2); its columns give the axis 2 extent.
        lung_sg = np.sum(lung, axis=0)
        lung_sg[lung_sg > 0] = 1
        largest_sg = largest_component(lung_sg)
        zmin, zmax = np.where(np.any(largest_sg, axis=0))[0][[0, -1]]

        data = infection[ymin:ymax + 1, xmin:xmax + 1, zmin:zmax + 1]
        new_data = nearest_resize(data, NEW_SIZE)

        if PREVIEW:
            plt.figure('check', (18, 6))
            plt.subplot(1, 2, 1)
            plt.title('image ')
            plt.imshow(new_data[:, :, 50], cmap='gray')
            plt.subplot(1, 2, 2)
            plt.title('image ')
            plt.imshow(new_data[:, :, 90], cmap='gray')
            print(new_data.shape)
            print(np.min(new_data), np.max(new_data))
            plt.show()

        # Identity affine: the original scanner geometry is not carried over.
        new_data = nib.Nifti1Image(new_data, np.eye(4))
        nib.save(new_data, os.path.join(series_dir, 'cropped_and_resized_infection.nii.gz'))

In [None]:
import cv2
import matplotlib
import scipy
data_dir = '/home/marafath/scratch/iran_organized_data2'

for patient in os.listdir(data_dir):
    # Last character of the patient folder name is the class label (0/1).
    # Named class_label so it does not clobber skimage.measure.label used elsewhere.
    class_label = int(patient[-1])
    for series in os.listdir(os.path.join(data_dir, patient)):
        volume = nib.load(os.path.join(data_dir, patient, series, 'cropped_and_resized_infection.nii.gz')).get_fdata()

        # Reorder to (axis2, axis0, axis1) and sum over the last axis,
        # i.e. collapse original axis 1 -> 2-D image of shape (d2, d0).
        projection = np.transpose(volume, (2, 0, 1)).sum(axis=2)

        # Min-max scale to [0, 1]. An all-zero infection volume gives a constant
        # projection, and the unguarded division would turn it into NaNs.
        projection = projection - projection.min()
        peak = projection.max()
        if peak > 0:
            projection = projection / peak

        # Three identical channels; shape follows the data instead of a hard-coded 128x128.
        img3 = np.repeat(projection[:, :, np.newaxis], 3, axis=2)

        plt.figure('check', (18, 6))
        plt.title('image ' + str(class_label))
        plt.imshow(img3, cmap='gray')
        print(np.min(img3), np.max(img3))
        plt.show()
        

In [None]:
import SimpleITK as sitk