In [1]:
%%capture
!pip install torchmetrics

In [2]:
from pathlib import Path

# Create the project layout expected below. `exist_ok=True` keeps the cell idempotent
# on re-run (plain `mkdir` errors with "File exists"), and `parents=True` creates `data/`.
for directory in ("utils", "data/gt", "data/nuc", "data/act"):
    Path(directory).mkdir(parents=True, exist_ok=True)

In [None]:
# Project layout this notebook relies on:
#
#     lucyd.py
#     evaluate.py
#     train.py
#
#     utils/
#         loader.py
#         ssim.py
#
#     data/
#         gt/
#         nuc/
#         act/

In [1]:
import random

import torch

TRAIN_DEPTH = 32  # passed as `depth=` to every ImageLoader below
BATCH_SIZE = 16   # batch size for every DataLoader below

# Training shuffles the data and initialises weights randomly; seed the RNGs so a
# Restart & Run All reproduces the recorded results as closely as possible.
# NOTE(review): if utils/loader.py draws from numpy's RNG, seed numpy as well.
SEED = 0
torch.manual_seed(SEED)  # seeds the RNG on all devices
random.seed(SEED)

In [2]:
# Explicit imports instead of `from utils.loader import *`: these are the only names
# this notebook used from the star import (`torch` is needed by the save cell).
import torch
from utils.loader import ImageLoader, read_data

# read_data(<channel>) returns a pair, unpacked here as (input volumes, ground truth).
nuc_data, gt_data = read_data('nuc')
# NOTE(review): the ground truth returned for 'act' is discarded and `gt_data` from
# 'nuc' is reused for both channels; this assumes both reads return the same GT
# (there is a single data/gt folder) — confirm.
act_data, _ = read_data('act')

In [3]:
# Show the shape of each stack: nuc input, act input, ground truth.
for stack in (nuc_data, act_data, gt_data):
    print(stack.shape)

torch.Size([5, 128, 128, 128])
torch.Size([5, 128, 128, 128])
torch.Size([5, 128, 128, 128])


In [4]:
from torch.utils.data import DataLoader

# The first N_TRAIN_VOLUMES volumes are used for training, the rest for testing.
N_TRAIN_VOLUMES = 4

# Inputs and ground truth must line up volume-for-volume for the split to pair them.
assert nuc_data.shape == act_data.shape == gt_data.shape, "nuc/act/gt stacks are not aligned"
assert 0 < N_TRAIN_VOLUMES < len(gt_data), "split must leave at least one volume on each side"

nuc_data_train, nuc_data_test = nuc_data[:N_TRAIN_VOLUMES], nuc_data[N_TRAIN_VOLUMES:]
act_data_train, act_data_test = act_data[:N_TRAIN_VOLUMES], act_data[N_TRAIN_VOLUMES:]
gt_data_train, gt_data_test = gt_data[:N_TRAIN_VOLUMES], gt_data[N_TRAIN_VOLUMES:]


def make_dataloader(gt, inputs, shuffle):
    """Wrap a (ground truth, input) pair of volume stacks in ImageLoader + DataLoader.

    Returns (dataset, dataloader). Training loaders shuffle; evaluation loaders keep
    a fixed sample order so evaluation does not depend on shuffling.
    """
    dataset = ImageLoader(gt, inputs, depth=TRAIN_DEPTH)
    return dataset, DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=shuffle)


nuc_train_loader, nuc_train_dataloader = make_dataloader(gt_data_train, nuc_data_train, shuffle=True)
nuc_test_loader, nuc_test_dataloader = make_dataloader(gt_data_test, nuc_data_test, shuffle=False)

act_train_loader, act_train_dataloader = make_dataloader(gt_data_train, act_data_train, shuffle=True)
act_test_loader, act_test_dataloader = make_dataloader(gt_data_test, act_data_test, shuffle=False)

In [None]:
from lucyd import LUCYD, device
from train import train

# LUCYD with a single residual block, trained on the 'nuc' input volumes.
model_nuc = LUCYD(num_res=1).to(device)
num_params = sum(param.numel() for param in model_nuc.parameters())
print(f'number of parameters: {num_params}')

# The model returned by train() replaces the untrained one.
model_nuc = train(model_nuc, nuc_train_dataloader, nuc_test_dataloader)

number of parameters: 24964
 -- Staring training epoch 1 --
train loss: 0.5482646045864796
validation loss: 0.6721225532989802
validation ssim: 0.05304 +- 0.00281
validation psnr: -6.51668 +- 0.5558
 -- Staring training epoch 2 --
train loss: 0.3571650087006273
 -- Staring training epoch 3 --
train loss: 0.29906612232414265
 -- Staring training epoch 4 --
train loss: 0.24179293673085733
 -- Staring training epoch 5 --
train loss: 0.20026325412173557
 -- Staring training epoch 6 --
train loss: 0.16846546592572198
validation loss: 0.1747958691590336
validation ssim: 0.69965 +- 0.0129
validation psnr: 24.96316 +- 0.6758
 -- Staring training epoch 7 --
train loss: 0.14836960357901235
 -- Staring training epoch 8 --
train loss: 0.13753644979326968
 -- Staring training epoch 9 --
train loss: 0.1280154774198151
 -- Staring training epoch 10 --
train loss: 0.116915715253397
 -- Staring training epoch 11 --
train loss: 0.10982110072928715
validation loss: 0.1193968606574896
validation ssim: 0.7

In [None]:
from evaluate import evaluate

# Score the nuc-trained model on both test sets (same channel and cross-channel).
for channel, test_dataloader in (('nuc', nuc_test_dataloader), ('act', act_test_dataloader)):
    print(f'Nuc model, {channel} data evaluation')
    evaluate(model_nuc, test_dataloader)

Nuc model, nuc data evaluation
testing ssim: 0.95255 +- 0.00555
testing psnr: 28.57226 +- 0.9215
Nuc model, act data evaluation
testing ssim: 0.90248 +- 0.00354
testing psnr: 24.81751 +- 0.34687


In [None]:
# ===== Second model: LUCYD trained on the 'act' data, then evaluated on both test sets =====

In [5]:
from lucyd import LUCYD, device
from train import train

# Same architecture as the nuc model, trained on the 'act' input volumes instead.
model_act = LUCYD(num_res=1).to(device)
param_count = sum(p.numel() for p in model_act.parameters())
print(f'number of parameters: {param_count}')

model_act = train(model_act, act_train_dataloader, act_test_dataloader)

number of parameters: 24964
 -- Staring training epoch 1 --
train loss: 0.6762025594732128
testing loss: 0.7057296874036862
testing ssim: 0.02007 +- 0.00079
testing psnr: -13.34359 +- 0.15934
 -- Staring training epoch 2 --
train loss: 0.3066966314090279
 -- Staring training epoch 3 --
train loss: 0.254946117451823
 -- Staring training epoch 4 --
train loss: 0.22145816370639224
 -- Staring training epoch 5 --
train loss: 0.18615566621501148
 -- Staring training epoch 6 --
train loss: 0.17215503493733408
testing loss: 0.19417654480857674
testing ssim: 0.67191 +- 0.01191
testing psnr: 24.87618 +- 0.51892
 -- Staring training epoch 7 --
train loss: 0.16165327631375903
 -- Staring training epoch 8 --
train loss: 0.15363820399785877
 -- Staring training epoch 9 --
train loss: 0.13920431585733684
 -- Staring training epoch 10 --
train loss: 0.13342115494869922
 -- Staring training epoch 11 --
train loss: 0.12960585815531742
testing loss: 0.13248694461515745
testing ssim: 0.77492 +- 0.01209
t

In [7]:
from evaluate import evaluate

# Cross-channel check for the act-trained model: nuc test set first, then act.
eval_plan = [
    ('Act model, nuc data evaluation', nuc_test_dataloader),
    ('Act model, act data evaluation', act_test_dataloader),
]
for heading, test_dataloader in eval_plan:
    print(heading)
    evaluate(model_act, test_dataloader)

Act model, nuc data evaluation
testing ssim: 0.93357 +- 0.00199
testing psnr: 27.62928 +- 0.32335
Act model, act data evaluation
testing ssim: 0.94501 +- 0.0026
testing psnr: 27.83452 +- 0.51546


In [8]:
# Import explicitly: previously `torch` was only in scope via `from utils.loader import *`.
import torch

# Save only the learned parameters (state_dict) of each trained model.
for trained_model, weights_path in ((model_nuc, 'lucyd-nuc.pth'), (model_act, 'lucyd-act.pth')):
    torch.save(trained_model.state_dict(), weights_path)