In [1]:
"""Libraries"""
from tqdm import tqdm
import torch
import torch.nn as nn
from models.DDPM import Unet
from utils.dataset import import_dataset
import matplotlib.pyplot as plt

black_level=99.6
gain=2.2
readout_noise=3.0
white_level=2**16-1

def electron_repr(image_data):
    return (image_data - black_level)/ gain

def digital_repr(image_data):
    return (image_data * gain + black_level)

def noise_distr(image_data: torch.Tensor, factor: float):
    if not 0 <= factor < 1:
        raise ValueError("factor must be between 0 and 1")

    electron_repr(image_data)
    scaled_data = factor * image_data
    noise_var = (1 - factor) * torch.clip(scaled_data, 0, None) + \
                (1 - factor ** 2) * (electron_repr(readout_noise / gain)) ** 2
    return digital_repr(torch.normal(0, torch.sqrt(noise_var)))

def energy_norm(x: torch.Tensor):
    mu = 488.0386
    sigma = 3.5994
    return (x-mu)/sigma

def energy_denorm(x: torch.Tensor):
    mu = 488.0386
    sigma = 3.5994
    return torch.clip((x*sigma+mu),0,self.white_level)

In [2]:
"""Dataset"""
name = 'ls_ae'
mode = name.split('_')[-1]

train_loader, valid_loader = import_dataset(data_name=name, 
                                            batch_size=32, 
                                            image_size = 256,
                                            force_download = False)

Loading Dataset
Light Sheet data imported!


In [3]:
"""Model"""
model = Unet(dim=32,out_dim=1,dim_mults=(1, 2, 4, 8),channels=1)
lr=0.00002

model.cuda()

criterion=nn.MSELoss()
optimizer=torch.optim.Adam(model.parameters(), lr=lr)

In [6]:
"""Training"""
for epoch in range(10): 

    running_loss = 0.0
    for i, data in tqdm(enumerate(train_loader, 0)):
        _, labels = data
        labels=labels.to('cuda:0')
        noise = torch.zeros_like(labels)
        time=[]
        for i,l in enumerate(labels):
#             time.append(int(torch.rand(1).item()*1001))
            time.append(500)
            noise[i] = noise_distr(l, torch.linspace(0,1000,1001)[time[-1]]/1001)

        optimizer.zero_grad()
        
        labels = energy_norm(labels)
        noise = energy_norm(noise)

        outputs = model(labels,torch.tensor(time).to('cuda:0'))
        loss = criterion(outputs, noise)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        
    print(f'[{epoch + 1}] loss: {running_loss:.3f}')
    running_loss = 0.0

print('Finished Training')

56it [00:37,  1.49it/s]


[1] loss: 654637.562


56it [00:38,  1.47it/s]


[2] loss: 654291.133


56it [00:38,  1.46it/s]


[3] loss: 653945.077


56it [00:38,  1.46it/s]


[4] loss: 653573.440


56it [00:38,  1.47it/s]


[5] loss: 653261.459


56it [00:38,  1.47it/s]


[6] loss: 652926.775


56it [00:38,  1.46it/s]


[7] loss: 652587.662


56it [00:38,  1.46it/s]


[8] loss: 652196.117


56it [00:38,  1.45it/s]


[9] loss: 651869.792


56it [00:38,  1.46it/s]

[10] loss: 651465.348
Finished Training





In [None]:
"""Test"""
x,y = next(iter(valid_loader))
img = y[0]
img = energy_norm(img)

plt.figure()
plt.imshow(noise_distr(img,0.5))
plt.figure()
plt.imshow(model(img))