In [95]:
import os
import random
from pathlib import Path
from typing import Union

import h5py
import numpy as np
import torch
from datasets import KCoordDataset, seed_worker
from data_utils import *
from fastmri.data.subsample import EquiSpacedMaskFunc, RandomMaskFunc
from fastmri.data.transforms import tensor_to_complex_np, to_tensor
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt


file = '/itet-stor/mcrespo/bmicdatasets-originals/Originals/fastMRI/brain/multicoil_train/file_brain_AXT1POST_203_6000861.h5'
n_volumes= 1
n_slices = 8
with_mask= True
acceleration = 4
center_frac = 0.15

with h5py.File(file, "r") as hf:
    ground_truth = hf["reconstruction_rss"][()][: n_slices]
    volume_kspace = to_tensor(preprocess_kspace(hf["kspace"][()]))[:n_slices]


Model Architecture
 - Layers are composed of sinusoidal activation functions
 - Network is a concatenated set of 8 sinusoidal layers, hidden dimension 512
 - Fourier encoding

In [165]:
from torch import nn

class SineLayer(nn.Module):
    """Linear layer with sine activation. Adapted from Siren repo"""

    def __init__(
        self, in_features, out_features, bias=True, is_first=False, omega_0=30
    ):
        super().__init__()
        self.omega_0 = omega_0
        self.is_first = is_first
        self.in_features = in_features

        self.linear = nn.Linear(in_features, out_features, bias=bias)
        # self.linear = nn.utils.weight_norm(nn.Linear(in_features, out_features, bias=bias))
        # self.layer_norm = nn.LayerNorm(out_features)
        # self.batch_norm = nn.BatchNorm1d(out_features)
        self.init_weights()

    def init_weights(self): # initialization function to initialize the weights in the first layer and later layers with normalized criterion
        with torch.no_grad():
            if self.is_first:
                # self.linear.weight.uniform_ - modifies the current values of the weights and they are set to a normal distribution (-1/n, 1/n)
                self.linear.weight.uniform_(-1 / self.in_features, 1 / self.in_features)
            else: # The rest of layers are initialized with the following random values in the uniform distribution U()
                self.linear.weight.uniform_(
                    -np.sqrt(6 / self.in_features) / self.omega_0,
                    np.sqrt(6 / self.in_features) / self.omega_0,
                )

    def forward(self, x):
        # NOTE: Uncomment when using batch (or layer) normalization.
        # x = self.linear(x)
        # x = self.layer_norm(x)
        # x = self.batch_norm(x)
        # return torch.sin(self.omega_0 * x)

        return torch.sin(self.omega_0 * self.linear(x))
    

class Fourier_Enc(nn.Module):
    def __init__(self, c, L, initial_beta=10.0):
        
        super().__init__()
        self.L = L
        self.c = c
        
        # Trainable beta parameter
        self.beta = nn.Parameter(torch.tensor(initial_beta, dtype=torch.float32))
        # Is beta going to be trained??
        self.B = torch.FloatTensor(self.c , self.L).uniform_(-1, 1) 
        
        self.register_buffer('Bmatrix', self.B)
        self.scaled_B = self.B * self.beta
        
    def forward(self, coor):
        # Scale B matrix by the trainable beta
        
        # Compute Fourier angles: 2 * pi * (coor @( scaled_B)
        fourier_angle = (2 * np.pi * coor.unsqueeze(-1) * self.scaled_B)  # (n x 4) @ (4 x len_enc)

        # Apply cosine and sine, then concatenate along the last dimension
        cos_encoding = torch.cos(fourier_angle)  # (n x len_enc)
        sin_encoding = torch.sin(fourier_angle)  # (n x len_enc)

        # Concatenate both along the last dimension to get the desired shape
        encoded_fourier = torch.cat([cos_encoding, sin_encoding], dim=-1)  # (n x (2 * len_enc))
        return encoded_fourier       
        
                
class Siren(nn.Module):
    def __init__(self,      
            c: int=4 , 
            L: int=10, 
            hidden_dim: int=512, 
            n_layers: int=8, 
            out_dim: int=2, 
            omega_0: int=30,
            outermost_linear=False,
            ) -> None:
        
        super().__init__()
        
        # self.sine_layers = nn.ModuleList()
        self.L = L
        self.c = c
        fourier_dim = self.L*2*self.c
        
        self.sine_layers = nn.ModuleList()
        self.sine_layers.append(SineLayer(fourier_dim, hidden_dim, is_first=True, omega_0=omega_0))

        for i in range(n_layers-1):
            self.sine_layers.append(
                SineLayer(hidden_dim, hidden_dim, is_first=False, omega_0=omega_0)
            )
        
        # Regarding the last layer
        if outermost_linear:
            self.final_layer = nn.Linear(hidden_dim, out_dim)
            # For initialization purposes, don't keep track of the weights when you initialize them:
            with torch.no_grad():
                self.final_layer.weight.uniform_(-np.sqrt(6 / hidden_dim) / omega_0, np.sqrt(6 / hidden_dim) / omega_0)
            
        else:
            self.final_layer = SineLayer(hidden_dim, out_dim, is_first=False, omega_0=omega_0)
            
        self.sine_layers.append(self.final_layer)
        
        # self.sine_layers = nn.Sequential(self.sine_layers)
        # self.sine_layers = nn.ModuleList(self.sine_layers)
        # self.output_layer = nn.Linear(hidden_dim, out_dim)
        
    def forward(self, coords):
        encoder = Fourier_Enc(self.c, self.L)
        x = encoder(coords)
        for layer in self.sine_layers:
            x = layer(x)
        return x



Train the model 

In [151]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

# Assuming the Fourier_Enc and Siren classes are defined above this code

# Hyperparameters
len_enc = 10
initial_beta = 1
learning_rate = 0.001
num_epochs = 500
batch_size = 1_000
n_checkpoint = 100

# Initialize the model, dataset, and dataloader
fourier_encoder = Fourier_Enc(c=4,L=len_enc, initial_beta=initial_beta)
model = Siren(outermost_linear=True)
print(model)
# Move model to the appropriate device (CPU or GPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
fourier_encoder.to(device)
model.to(device)

# Define the optimizer and loss function
optimizer = optim.Adam(list(fourier_encoder.parameters()) + list(model.parameters()), lr=learning_rate)
loss_fn = nn.MSELoss()  # Example loss function; modify as needed

# Assuming you have a dataset class KCoordDataset defined and imported
dataset = KCoordDataset(path_to_data=file, n_volumes=1, n_slices=3)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Training loop
for epoch in range(num_epochs):
    avg_loss = 0.0
    n_obs = 0
    count = 0
    for inputs, targets in dataloader:
        # Move inputs and targets to the specified device
        inputs, targets = inputs.to(device), targets.to(device)

        # Zero the gradients
        optimizer.zero_grad(set_to_none=True)

        # Pass encoded inputs through the Siren model
        outputs = model(inputs)

        # Calculate the loss
        batch_loss = loss_fn(outputs, targets)

        # Backward pass and optimization
        batch_loss.backward()
        optimizer.step()

        # Update average loss and number of observations
        avg_loss += batch_loss.item() * len(inputs)
        n_obs += len(inputs)
            
        count += 1

    # Calculate average loss for the epoch
    avg_loss /= n_obs
    if epoch % n_checkpoint == 0:
        print(f"Epoch [{epoch + 1}/{num_epochs}], Loss: {avg_loss:.4f}")

# Save the model if needed
torch.save(model.state_dict(), 'siren_model.pth')



Siren(
  (sine_layers): ModuleList(
    (0): SineLayer(
      (linear): Linear(in_features=20, out_features=512, bias=True)
    )
    (1-7): 7 x SineLayer(
      (linear): Linear(in_features=512, out_features=512, bias=True)
    )
    (8): Linear(in_features=512, out_features=2, bias=True)
  )
  (final_layer): Linear(in_features=512, out_features=2, bias=True)
)
Epoch [1/500], Loss: 0.0066


KeyboardInterrupt: 

In [160]:
L_mult = torch.pow(2, torch.arange(10)) * np.pi
x = kspace_coords.unsqueeze(-1) * L_mult
x = torch.cat([torch.sin(x), torch.cos(x)], dim=-1)
x = x.view(x.size(0), -1)
print(x[0,...].shape)

torch.Size([80])


In [164]:
(2 * np.pi * coor.unsqueeze(-1) * B_scaled).shape

torch.Size([4, 10])

In [100]:
dataset = KCoordDataset(file,)
# Ground truth (used to compute the evaluation metrics).
dataloader = DataLoader(dataset, batch_size=120, num_workers=3, shuffle=True)
file = dataloader.dataset.metadata[0]["file"]

with h5py.File(file, "r") as hf:
    ground_truth = hf["reconstruction_rss"][()][
        : n_slices]
    
# plt.subplot(1, 2, 1)
# plt.imshow(ground_truth[0,...])
# plt.axis('off')

avg_loss = 0.0
n_obs = 0

class Trainer():
    def __init__(
        self, device, model, loss_fn, optimizer, scheduler, config,
    ):
        self.model = model
        
        if hasattr(loss_fn, "to"):
            self.loss_fn = loss_fn.to(self.device)
        else:
            self.loss_fn = loss_fn
            
        self.optimizer = optimizer
        self.scheduler = scheduler
        
        self.model.train()
    
    def train(self):
        for inputs, targets in dataloader:
            inputs, targets = inputs.to(self.device), targets.to(self.device)

            self.optimizer.zero_grad(set_to_none=True)

            outputs = self.model(inputs)
            # Can be thought as a moving average (with "stride" `batch_size`) of the loss.
            batch_loss = self.loss_fn(outputs, targets)
            # NOTE: Uncomment for some of the loss functions (e.g. 'MSEDistLoss').
            # batch_loss = self.loss_fn(outputs, targets, inputs)

            batch_loss.backward()
            self.optimizer.step()

            avg_loss += batch_loss.item() * len(inputs)
            n_obs += len(inputs)

        self.scheduler.step()
        avg_loss = avg_loss / n_obs
        return avg_loss

n_epochs = 100
empirical_risk = 0
for epoch_idx in range(n_epochs):

    print(f"EPOCH {epoch_idx}    avg loss: {empirical_risk}\n")
