### Demo: how to consume OADAT with PyTorch

In [1]:
import os, sys, glob
import h5py
import numpy as np
import matplotlib.pyplot as plt

# When launched from the notebooks/ folder, switch to the repository root so
# that repo-local modules (e.g. `dataset`) can be imported. The root is
# resolved to an absolute path *before* chdir: a relative '..' left on
# sys.path would be re-resolved against the new working directory and point
# one level above the repository root.
if os.path.basename(os.getcwd()) == 'notebooks':
    repo_root = os.path.abspath('..')
    sys.path.append(repo_root)
    os.chdir(repo_root)

import dataset

### Define the base data reader object for OADAT

In [None]:
# Path to the OADAT parent directory -- change this to match your setup.
oadat_dir = '/mydata/dlbirhoui/firat/OADAT'
fname_SWFDsc = os.path.join(oadat_dir, 'SWFD_semicircle_RawBP.h5')
key = 'sc_BP'  # semi-circle back-projection

with h5py.File(fname_SWFDsc, 'r') as f:
    print(f.keys())
    pIDs = f['patientID'][()]
    # Count samples per patient in a single pass, instead of one full
    # comparison over pIDs for every unique ID.
    unique_pIDs, pID_freqs = np.unique(pIDs, return_counts=True)
    # .tolist() turns NumPy scalars into plain Python values so the printed
    # dict stays readable.
    pID_counts = dict(zip(unique_pIDs.tolist(), pID_freqs.tolist()))
    print(f'Unique patient IDs (and counts): {pID_counts}')
    num_images = f[key].shape[0]

prng = np.random.RandomState(42)


def scaleclip_fn(x):
    """Scale an optoacoustic image by its maximum, then clip values below -0.2.

    This is the standard scaling used for the OADAT image data.
    NOTE(review): assumes the image maximum is positive -- a zero maximum
    would produce NaN/inf; confirm against the data.
    """
    return np.clip(x / np.max(x), a_min=-0.2, a_max=None)


dataset_obj = dataset.Dataset(fname_h5=fname_SWFDsc, key=key, transforms=scaleclip_fn, inds=None, shuffle=True, prng=prng)


### PyTorch-native data reader example

In [None]:
import itertools

from torch.utils.data import DataLoader

train_set = dataset_obj
num_workers = 4
# NOTE(review): with num_workers > 0 each worker process receives a copy of
# the dataset; this assumes dataset.Dataset handles its HDF5 file in a
# worker-safe way -- confirm in dataset.py.
train_loader = DataLoader(train_set, batch_size=8, shuffle=True, num_workers=num_workers, drop_last=True)

# Preview exactly `kill_iter` batches; islice stops the iteration cleanly.
kill_iter = 5
for i, x in enumerate(itertools.islice(train_loader, kill_iter)):
    print(f'Batch {i+1}: {x.shape}')

### PyTorch training loop
---
### For real training, please use the `lightning.pytorch.Trainer` module and implement everything else as callbacks.
### Also, never train inside a notebook; this is only a demo.

In [None]:
import torch.nn as nn
import torch

class Net(nn.Module):
    """Minimal two-layer convolutional network used by this demo.

    A single-channel input is mapped to 32 feature channels and back to one
    channel. Both convolutions use 3x3 kernels with stride 1 and padding 1,
    so the spatial size of the input is preserved. There is no activation
    between the layers, so the network is a purely linear map.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(32, 1, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        hidden = self.conv1(x)
        return self.conv2(hidden)
    
model = Net()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)
loss_fn = nn.MSELoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)


num_epochs = 5
for epoch in range(num_epochs):
    model.train()
    for i, x in enumerate(train_loader):
        # DataLoader yields CPU tensors: move each batch to the model's device
        # and cast to float32 to match the Conv2d weights, otherwise the
        # forward pass fails on GPU or on non-float32 data.
        x = x.to(device=device, dtype=torch.float32)
        # Conv2d(1, ...) needs an (N, 1, H, W) batch; add the channel axis
        # when the dataset yields plain (H, W) images.
        if x.ndim == 3:
            x = x.unsqueeze(1)
        optimizer.zero_grad()
        y = model(x)
        # Reconstruction objective: the model output is compared to its input.
        loss = loss_fn(y, x)
        loss.backward()
        optimizer.step()
        print(f'Epoch {epoch+1}, Batch {i+1}, Loss: {loss.item()}')

    model.eval()
    ## track validation metrics (inside torch.no_grad()), save model, etc.