In [None]:
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import torch
import torch.nn as nn
from sklearn.metrics import confusion_matrix
from torch.utils.data import DataLoader, random_split
from tqdm import tqdm

from cc_hardware.algos.datasets import HistogramDataset

## Loading a Dataset

In [None]:
# Location of the recorded histogram dataset; replace the placeholder before running.
dataset_path = "[PATH TO YOUR DATASET]"
dataset = HistogramDataset(dataset_path)

In [None]:
# Visualize the first sample of the dataset (no averaging is done here).
histograms = dataset.inputs[0]  # expected shape [4, 4, num_bins] — TODO confirm against HistogramDataset

# Create a 4x4 grid of bar plots, one per sensor pixel
fig, axes = plt.subplots(4, 4, figsize=(12, 12))

# One x position per histogram bin. Derived from the data rather than hard-coded,
# because `bar` raises if the x positions and bar heights differ in length.
x = np.arange(histograms.shape[-1])

for i in range(4):
    for j in range(4):
        axes[i, j].bar(x, histograms[i, j].numpy(), color='b', alpha=0.75)
        axes[i, j].set_title(f"Pixel ({i},{j})")
        axes[i, j].set_xlabel("Bin Number")
        axes[i, j].set_ylabel("Num Photons")

plt.tight_layout()
plt.show()

In [None]:
# Report the shapes of the full input and target tensors held by the dataset.
input_shape = dataset.inputs.shape
print(f'input shape: {input_shape}')
target_shape = dataset.targets.shape
print(f'target shape: {target_shape}')

In [None]:
# Split sizes: 50% train, 25% validation, and whatever remains for test
# (the remainder absorbs any rounding from the int() truncation).
n_total = len(dataset)
train_size = int(0.5 * n_total)
val_size = int(0.25 * n_total)
test_size = n_total - train_size - val_size

# A fixed seed makes the split identical across runs.
split_rng = torch.Generator().manual_seed(1)
train_dataset, val_dataset, test_dataset = random_split(
    dataset,
    [train_size, val_size, test_size],
    generator=split_rng,
)

batch_size = 32

# Only the training loader shuffles; validation and test keep a stable order.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

In [None]:
# Pick the compute backend in order of preference: CUDA GPU, then Apple MPS, then CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
print(f"Using {device} device")

## Training

In [None]:
from cc_hardware.algos.models import DeepLocation8, initialize_weights

# Build the network and place it on the selected device in one step.
model = DeepLocation8().to(device)
# Leave the model as the cell's last expression so its layer summary is displayed.
model

In [None]:
# nn.Module.apply calls initialize_weights on the model and on each of its submodules.
# Presumably this (re-)initializes the layer weights — see cc_hardware.algos.models to confirm.
model.apply(initialize_weights)

In [None]:
# Mean squared error between predicted and true targets. Note this is the mean of the
# squared per-coordinate errors, not the Euclidean distance itself.
loss_fn = nn.MSELoss()
# Adam over every parameter of the model.
optimizer = torch.optim.Adam(params=model.parameters(), lr=0.01)

In [None]:
def train(dataloader, model, loss_fn, optimizer, clipping=False, debug=False):
    """Run one training epoch and return the mean loss over the processed batches.

    Batches smaller than the loader's configured batch size (typically the last
    one) are skipped. The returned average is taken over the batches that were
    actually used, so skipped batches no longer drag the mean towards zero.

    Args:
        dataloader: DataLoader yielding ``(inputs, targets)`` batches.
        model: Module to optimize; it is switched to training mode.
        loss_fn: Callable ``loss_fn(pred, target)`` returning a scalar tensor.
        optimizer: Optimizer that updates ``model``'s parameters.
        clipping: If True, clip the global gradient norm to 1.0 before each step.
        debug: If True, print the mean/max absolute gradient of every parameter
            after each backward pass.

    Returns:
        Mean per-batch loss as a float; 0.0 if no batch was processed.

    Note:
        Tensors are moved to the notebook-level ``device``.
    """
    size = len(dataloader.dataset)
    # Use the loader's own batch size instead of a notebook global so the function
    # stays correct for any loader. It is None when a custom batch_sampler is used,
    # in which case nothing is skipped.
    full_batch_size = dataloader.batch_size

    model.train()
    train_loss = 0.0
    num_processed = 0
    for batch, (X, y) in enumerate(dataloader):
        if full_batch_size is not None and len(X) < full_batch_size:
            continue

        X, y = X.to(device), y.to(device)
        # Compute prediction error
        pred = model(X)
        loss = loss_fn(pred, y)
        loss_value = loss.item()
        train_loss += loss_value
        num_processed += 1

        # Backpropagation
        loss.backward()

        if clipping:
            # Gradient clipping
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)  # Adjust max_norm as needed

        if debug:
            # Inspect gradients for each layer
            for name, param in model.named_parameters():
                if param.grad is not None:  # Only check if gradient is computed
                    print(f"Layer: {name} | Gradient mean: {param.grad.abs().mean().item()} | Gradient max: {param.grad.abs().max().item()}")
                else:
                    print(f"Layer: {name} has no gradient.")

        optimizer.step()
        optimizer.zero_grad()

        if batch % 100 == 0:
            current = (batch + 1) * len(X)
            print(f"loss: {loss_value:>7f}  [{current:>5d}/{size:>5d}]")

    # Average over the batches that contributed to the sum; max() guards the
    # case where every batch was skipped.
    return train_loss / max(num_processed, 1)

In [None]:
def test(dataloader, model, loss_fn):
    """Evaluate ``model`` on ``dataloader`` and return the mean per-batch loss.

    The model is put in evaluation mode and gradients are disabled. The average
    is printed and returned. Tensors are moved to the notebook-level ``device``.
    """
    batch_count = len(dataloader)
    model.eval()
    with torch.no_grad():
        # Accumulate the scalar loss of every batch, in loader order.
        loss_sum = sum(
            loss_fn(model(inputs.to(device)), targets.to(device)).item()
            for inputs, targets in dataloader
        )
    test_loss = loss_sum / batch_count
    print(f"Test Error: \n Avg loss: {test_loss:>8f} \n")
    return test_loss

In [None]:
def train_early_stopping(train_loader, val_loader, model, loss_fn, optimizer,
    epochs=50, early_stopping=True, patience=5, threshold=0.15, clipping=False, debug=False):
    """Train for up to ``epochs`` epochs, optionally stopping early on validation loss.

    After every epoch the model is evaluated on ``val_loader``. With
    ``early_stopping`` enabled, a copy of the weights is kept whenever the
    validation loss reaches a new minimum. Training stops once the validation
    loss has exceeded ``best * (1 + threshold)`` in ``patience`` epochs since
    the last new minimum.

    Args:
        train_loader: DataLoader for the training set.
        val_loader: DataLoader for the validation set.
        model: Module to train. Its class must be constructible without
            arguments, because the returned model is built via ``model.__class__()``.
        loss_fn: Loss callable passed through to ``train`` and ``test``.
        optimizer: Optimizer bound to ``model``'s parameters.
        epochs: Maximum number of epochs.
        early_stopping: Whether to track the best weights and stop early.
        patience: Number of "too high" validation epochs tolerated before stopping.
        threshold: Relative margin above the best validation loss that counts
            as "too high".
        clipping: Forwarded to ``train``.
        debug: Forwarded to ``train``.

    Returns:
        ``(best_model, train_losses, val_losses)``. ``best_model`` is a fresh
        instance holding the weights of the best validation epoch, or the final
        weights when early stopping is disabled or no epoch was run.
    """
    def snapshot_weights():
        # state_dict() hands back references to the live tensors; clone them so
        # later optimizer steps cannot overwrite the saved weights.
        return {key: value.detach().clone() for key, value in model.state_dict().items()}

    train_losses = []
    val_losses = []
    best_val_loss = float('inf')
    best_model_state = None
    patience_counter = 0  # defined up front so it can never be read before assignment
    for t in range(epochs):
        print(f"Epoch {t+1}\n-------------------------------")
        train_loss = train(train_loader, model, loss_fn, optimizer, clipping=clipping, debug=debug)
        val_loss = test(val_loader, model, loss_fn)
        train_losses.append(train_loss)
        val_losses.append(val_loss)

        if not early_stopping:
            continue

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            best_model_state = snapshot_weights()
            patience_counter = 0
        # Multiplicative form of `val_loss / best_val_loss > 1 + threshold`;
        # avoids dividing by zero when the best loss is exactly 0.
        elif val_loss > best_val_loss * (1 + threshold):
            patience_counter += 1
            if patience_counter >= patience:
                print(f"Early stopping at epoch {t+1}")
                break

    if best_model_state is None:
        # Early stopping disabled (or no improving epoch recorded): fall back to
        # the weights the model currently holds.
        best_model_state = snapshot_weights()

    # Build the returned model once, after training, rather than once per epoch.
    best_model = model.__class__().to(device)
    best_model.load_state_dict(best_model_state)
    return best_model, train_losses, val_losses

In [None]:
# Train with gradient clipping and early stopping enabled.
# NOTE(review): patience (10) is not smaller than epochs (10), so the early-stop
# condition looks unreachable in this run — confirm these values are intended.
best_model, train_losses, val_losses = train_early_stopping(
    train_loader,
    val_loader,
    model,
    loss_fn,
    optimizer,
    epochs=10,
    early_stopping=True,
    patience=10,
    threshold=0.4,
    clipping=True,
)

# Plot the per-epoch training and validation loss curves on one set of axes.
loss_fig, loss_ax = plt.subplots()
loss_ax.plot(train_losses, label='train')
loss_ax.plot(val_losses, label='val')
loss_ax.legend()
plt.show()

# Continue the notebook with the model returned by training.
model = best_model

### Saving and Loading Models

In [None]:
# File the trained weights are written to, and read back from, in the cells below.
model_save_path = 'outputs/example_model.mdl'

In [None]:
# torch.save does not create missing directories, so make sure the target
# folder exists before writing the weights (a fresh checkout may lack it).
Path(model_save_path).parent.mkdir(parents=True, exist_ok=True)
torch.save(model.state_dict(), model_save_path)

In [None]:
# Rebuild the architecture and restore the saved weights. map_location remaps the
# stored tensors onto the current device, so a checkpoint written on a GPU machine
# can still be loaded on a CPU- or MPS-only machine.
model = DeepLocation8().to(device)
model.load_state_dict(torch.load(model_save_path, map_location=device))

## Evaluation

In [None]:
# Score the reloaded model on the held-out test split. The returned average loss
# is the cell's last expression, so it is also shown as the cell output.
model.eval()
test(dataloader=test_loader, model=model, loss_fn=loss_fn)

In [None]:
# Show predictions for up to the first 20 test examples. The count is capped so a
# small test split does not raise IndexError.
num_examples = min(20, len(test_dataset))

# Evaluation mode and no-grad only need to be set once for the whole loop.
model.eval()
with torch.no_grad():
    for i in range(num_examples):
        # Get a single example from the test dataset
        example_data, example_label = test_dataset[i]
        example_label = example_label.to(device)

        # Add a leading batch dimension and move the example to the device
        example_data = example_data.unsqueeze(0).to(device)

        # Get the model's prediction and drop the batch dimension again
        output = model(example_data).squeeze()

        print(f'Prediction: {output}, Actual label: {example_label}, Distance: {torch.norm(output - example_label):.4f}')