In [1]:
import torch
from torch import nn
from torch.optim import lr_scheduler
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.tensorboard import SummaryWriter
import structured_light_tomography.models as models
import structured_light_tomography.training as training
import structured_light_tomography.dataset_generation as dg
import structured_light_tomography.photocount_treatment as pt
from torchvision.transforms import v2
import numpy as np
import matplotlib.pyplot as plts
from os.path import join
import torchvision
import torch.nn.functional as F
import h5py
import matplotlib.pyplot as plt
from torch.utils.data import random_split

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using {device} device")

Using cuda device


In [2]:
# Dataset order: selects the `*_order{order}` HDF5 keys read below and sizes the
# model output as 2 * (order + 1).
order: int = 3

In [None]:
# Load the simulated training set for the chosen order, apply photon-count sampling,
# and split it into train / validation loaders.
with h5py.File('TrainingData/pure.h5', 'r') as f:
    x = f[f'x_order{order}'][:]
    y = f[f'y_order{order}'][:]

photocounts = 2048
# NOTE(review): the return value of sample_photons is discarded, so this only has an
# effect if it modifies `x` in place — confirm against dataset_generation.sample_photons.
dg.sample_photons(x, photocounts)
# NOTE(review): torch.from_numpy keeps the HDF5 dtype (the experimental cell casts to
# float32 explicitly), and no per-channel normalization is applied here although the
# experimental inputs are normalized before evaluation — confirm the model sees
# consistently scaled float32 inputs in both places.
dset = TensorDataset(torch.from_numpy(x), torch.from_numpy(y))

train_fraction = 0.8  # 80% train / 20% validation
batch_size = 256
split_seed = 0        # fixed so the train/validation split is reproducible across runs

train_size = int(train_fraction * len(dset))
test_size = len(dset) - train_size

train_dataset, test_dataset = random_split(
    dset,
    [train_size, test_size],
    generator=torch.Generator().manual_seed(split_seed),
)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size)

In [None]:
# Model, loss and optimizer for the chosen order.
loss_fn = training.fidelity_loss

L = x.shape[2]               # NOTE(review): used for both spatial sizes below — presumes square inputs
n_channels = 2               # NOTE(review): hard-coded; presumably equals x.shape[1] — confirm
n_classes = 2 * (order + 1)  # number of model outputs for this order

model = models.ConvNet(
    L, L,           # spatial size, passed twice (presumably height and width)
    n_channels,
    n_classes,
    [24, 40, 35],   # presumably per-layer conv filter counts — see models.ConvNet
    5,              # presumably the conv kernel size — confirm in models.ConvNet
    nn.ELU,         # activation class
    [120, 80, 40],  # presumably hidden fully-connected widths — see models.ConvNet
).to(device)

# Adam with the AMSGrad variant and the default learning rate.
optimizer = torch.optim.Adam(params=model.parameters(), amsgrad=True)

In [None]:
# Train with early stopping. TensorBoard logs go to save_path, which is also handed to
# EarlyStopping (presumably where it writes the checkpoint.pt the evaluation cell loads).
# FIXME(review): "Order1" is hard-coded although `order` is 3. The evaluation cell loads
# the same hard-coded path, so change both together (e.g. to f"Order{order}").
save_path = f"runs/Pure/Order1/{photocounts}Photocounts"
max_epochs = 200

writer = SummaryWriter(save_path)
early_stopping = training.EarlyStopping(patience=50, save_path=save_path)
try:
    for epoch in range(1, max_epochs + 1):
        print(f"-------------------------------\nEpoch {epoch}")
        training.train(model, train_loader, loss_fn, optimizer, device)
        val_loss = training.test(model, test_loader, loss_fn, device, epoch, writer, verbose=True)
        early_stopping(val_loss, model)

        if early_stopping.early_stop:
            print("Early stopping")
            break
    print("Done!")
finally:
    # Close the writer even if training raises or is interrupted, so logged events are flushed.
    writer.close()

In [3]:
"""with h5py.File('ExperimentalData/Intense/pure.h5', 'r') as f:
    x_exp = np.float32(f[f'images_order{order}'][:])
    y_exp = f[f'coefficients_order{order}'][:]"""

with h5py.File('ExperimentalData/Photocount/datasets.h5', 'r') as f:
    histories = f[f'histories_order{order}'][:]
    x_exp = np.float32(np.array([pt.array_representation(history,(2,64,64)) for history in histories]))
    y_exp = dg.real_representation(f[f'coefficients_order{order}'][:])

In [5]:
# Per-channel normalization statistics, computed over the experimental set itself.
# NOTE(review): the training loader applies no normalization, so evaluation inputs may be
# scaled differently from what the checkpoint was trained on — confirm this is intended.
mean, std = [], []
for channel in range(x_exp.shape[1]):
    channel_values = x_exp[:, channel, :, :]
    mean.append(channel_values.mean())
    std.append(channel_values.std())

# numpy -> tensor, normalize per channel, then resize. The photocount frames are already
# (2, 64, 64), so the resize presumably leaves their size unchanged.
to_model_input = v2.Compose([
    torch.from_numpy,
    v2.Normalize(mean=mean, std=std),
    v2.Resize((64, 64)),
])
X = to_model_input(x_exp).to(device)
Y = torch.from_numpy(y_exp).to(device)



In [6]:
# Sanity-check the evaluation tensors before scoring the model.
assert X.shape[0] == Y.shape[0], "X and Y must contain the same number of samples"
assert Y.shape[1] == 2 * (order + 1), "coefficient width must equal the model output size for this order"
Y.shape

torch.Size([50, 8])

In [8]:
from structured_light_tomography.training import fidelity

# Load the saved checkpoint and report the mean fidelity on the experimental set.
# The checkpoint holds a whole pickled nn.Module (it is used directly as `model`), so
# weights_only=False is required on PyTorch >= 2.6, where torch.load defaults to
# weights_only=True. Unpickling can execute arbitrary code, so only load trusted files.
# map_location lets a checkpoint saved on CUDA load on a CPU-only machine.
# FIXME(review): "Order1" is hard-coded even though `order` is 3. It matches the training
# cell's save_path, so change both together.
checkpoint_path = "runs/Pure/Order1/2048Photocounts/checkpoint.pt"
model = torch.load(checkpoint_path, map_location=device, weights_only=False)
model.eval()

# Inference only: skip building an autograd graph.
with torch.no_grad():
    mean_fidelity = fidelity(model(X), Y).mean()
mean_fidelity

tensor(0.8079, device='cuda:0', grad_fn=<MeanBackward0>)