In [1]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3"
import astra
astra.set_gpu_index((0, 1, 2, 3))

from AdversarialRegularizer import AdversarialRegulariser
from networks import AlexNet_3D
from networks_FoV import ResNetL2_local
from ellipsgen.generate_data_function import get_batch
import numpy as np
import ellipsgen.CCB_class as CT
import ellipsgen.phantom_class as ph
import matplotlib.pyplot as plt
import time

BATCH_SIZE = 4
LOOPS = 2
IMG_SIZE = 128
STEPS = 100

In [2]:
import tensorflow as tf
tf.test.gpu_device_name()

'/device:GPU:0'

In [None]:
# Parameter choices. Heuristic in the BWGAN paper: Choose GAMMA as average dual norm of clean image
# LMB should be bigger than product of norm times dual norm.

# For s=0.0, this implies GAMMA =1.0
# For s=1.0, have GAMMA = 10.0 as realisitc value
S = 0.0
LMB = 10.0
GAMMA = 1.0
CUTOFF = 20.0

saves_path = '/export/scratch1/home/voxels-gpu0/codesprint_learned_prior/AdvRegSaves/test2/ell_noisy_' + str(IMG_SIZE)
regularizer = AdversarialRegulariser(saves_path, IMAGE_SIZE=(None, IMG_SIZE, IMG_SIZE, IMG_SIZE, 1), NETWORK=AlexNet_3D,
                                     s=S, cutoff=CUTOFF, lmb=LMB, gamma=GAMMA)



In [3]:
def load_pair(i):
    prefix = '../noisy_' + str(IMG_SIZE) + '/'
    fdk = np.load(prefix + 'FDK/dataset' + str(i) + '.npy')
    gt = np.load(prefix + 'GT/dataset' + str(i) + '.npy')
    
    return fdk, gt


def generate_pairs(amount=1):
    num_angles = 360
    noise = ['Poisson', 2 ** 14] # Do not go below 2 ** 8, lower number in the 2nd argument means more noise
    num_vox = IMG_SIZE
    
    voxels = [num_vox, num_vox, num_vox]
    src_rad = 10
    det_rad = 0

    fdks = np.empty((amount, *voxels))
    gts = np.empty_like(fdks)

    for i in range(amount):
        data_obj = ph.phantom(voxels, '22 Ellipses', num_angles, noise, src_rad, det_rad)
        gts[i, ...] =  data_obj.f

        case = CT.CCB_CT(data_obj)
        fdks[i, ...] = case.do_FDK()
        
        if i == 0:
            plt.figure()
            plt.imshow(gts[i, int(IMG_SIZE/2), ...])
            plt.show()
            
            plt.figure()
            plt.imshow(fdks[i, int(IMG_SIZE/2), ...])
            plt.show()
            
    
    return (gts, fdks)
    

def evaluate(gt, adv):
    # gt, adv = get_batch(eval_data=True, noise_levels=['01', '016'], methods=['EM', 'SGD'])
    regularizer.test(groundTruth=gt, adversarial=adv)

    
def train(steps):
    for k in range(steps):        
#         if k%50 == 0:
#             gt, adv = generate_pairs(amount=1) # batch, x, y, z, channel
#             evaluate(gt, adv)
#         else:
            gts, advs = load_pair(k) # batch, x, y, z, channel
    
            print(k)
            t = time.time()
            regularizer.train(groundTruth=gts, adversarial=advs, learning_rate=LEARNING_RATE)
            print(time.time() - t)
            
    regularizer.save()

In [None]:
LEARNING_RATE = 0.00005
for k in range(LOOPS):
    train(STEPS)

LEARNING_RATE = 0.00002
for k in range(LOOPS):
    train(STEPS)

0
39.60244154930115
1
0.7949061393737793
2
0.7970523834228516
3
0.7945342063903809
4
0.7928497791290283
5
0.7941534519195557
6
0.7982933521270752
7
0.7849075794219971
8
0.7903594970703125
9
0.7923219203948975
10
0.7944197654724121
11
0.796797513961792
12
0.7938878536224365
13
0.7933619022369385
14
0.790978193283081
15
0.7870767116546631
16
0.7925925254821777
17
0.7885901927947998
18
0.7891812324523926
19
0.7928667068481445
20
0.8004090785980225
21
0.7996535301208496
22
0.7995331287384033
23
0.8010807037353516
24
0.803351879119873
25
0.804323673248291
26
0.8056373596191406
27
0.8030939102172852
28
0.8037822246551514
29
0.800816535949707
30
0.803708553314209
31
0.803783655166626
32
0.8017001152038574
33
0.7944247722625732
34
0.7971842288970947
35
0.7987008094787598
36
0.8013761043548584
37
0.8015506267547607
38
0.8007853031158447
39
0.8022935390472412
40
0.8008809089660645
41
0.7989768981933594
42
0.8005707263946533
43
0.8003654479980469
44
0.8045811653137207
45
0.8022308349609375
46
0.7