In [1]:
import os
from pathlib import Path
from typing import Optional

import fastmri
import h5py
import matplotlib.pyplot as plt
import numpy as np
import torch
from data_utils import *
from datasets import *
from fastmri.data.transforms import tensor_to_complex_np
from skimage.metrics import peak_signal_noise_ratio, structural_similarity
from torch.utils.data import DataLoader, TensorDataset

from model import *
from torch.optim import SGD, Adam, AdamW
from train_utils import *

In [None]:
files = '/itet-stor/mcrespo/bmicdatasets-originals/Originals/fastMRI/brain/multicoil_train/'

dataset = KCoordDataset(files, n_volumes=3, n_slices=3, with_mask=False)
print(len(dataset))
# loader_config = config["dataloader"]
# dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, shuffle=True, collate_fn=collate_fn, pin_memory=PIN_MEMORY, worker_init_fn=seed_worker, generator=RS_TORCH)
dataloader = DataLoader(
    dataset,
    batch_size=120_000,
    num_workers=0,
    shuffle=True,
    pin_memory=True,
    )

14534400


In [19]:
ground_truth = []
n_slices = 3


for vol_id in dataloader.dataset.metadata.keys():
    file = dataloader.dataset.metadata[vol_id]["file"]
    with h5py.File(file, "r") as hf:
        ground_truth.append(
            hf["reconstruction_rss"][()][: n_slices]
        )


In [None]:
model_checkpoint = '/scratch_net/ken/mcrespo/proj_marina/logs/multivol/2024-11-11_09h28m59s/checkpoints/epoch_0999.pt'  # TODO: SET (OR LEAVE COMMENTED).


gamma = 0.1
sigma = 0.01
lr = 5.e-6
embedding_dim = 512

OPTIMIZER_CLASSES = {
    "Adam": Adam,
    "AdamW": AdamW,
    "SGD": SGD,
}

LOSS_CLASSES = {
    "MAE": MAELoss,
    "DMAE": DMAELoss,
    "MSE": MSELoss,
    "MSEDist": MSEDistLoss,
    "HDR": HDRLoss,
    "LogL2": LogL2Loss,
    "MSEL2": MSEL2Loss,
}


model = Siren(hidden_dim=512, embedding_dim=512, L =10, n_layers=8, out_dim=2)
# Load checkpoint.
model_state_dict = torch.load(model_checkpoint, map_location=torch.device('cpu'))
model.load_state_dict(model_state_dict["model_state_dict"])
print("Checkpoint loaded successfully.")

# Only embeddings are optimized.
for param in model.parameters():
    param.requires_grad = False


embeddings = torch.nn.Embedding(
    len(dataset.metadata), embedding_dim
)
torch.nn.init.normal_(
    embeddings.weight.data, 0.0, sigma
)
optimizer = OPTIMIZER_CLASSES["Adam"](
    embeddings.parameters(), lr
)

loss_fn = LOSS_CLASSES["MSEL2"](gamma, sigma)

  model_state_dict = torch.load(model_checkpoint, map_location=torch.device('cpu'))


Checkpoint loaded successfully.


In [23]:
## START TRAINING FROM CHECKPOINT
counter = 0
for inputs, targets in dataloader:
    counter += 1
    # Inputs has dimension Nm x 5, position 0 corresponds to volID
    coords, latent_embeddings = inputs[:, 1:], embeddings(
        inputs[:, 0].long()
    )
    
    if counter > 0:
        break
    
    

In [27]:
print(latent_embeddings[0] - latent_embeddings[1])

tensor([-2.5269e-02,  7.6233e-03,  2.2291e-02,  1.2017e-02, -1.5466e-02,
        -6.5045e-03, -7.5966e-03, -2.1088e-02,  9.0111e-04, -1.0980e-04,
         1.6407e-02,  1.2758e-02, -2.8559e-03,  1.3648e-02,  6.4741e-03,
         1.7174e-02,  1.1482e-02,  6.7015e-03, -1.2104e-02, -1.2050e-02,
         5.2339e-03, -1.9295e-02, -1.5418e-02,  1.9169e-04, -1.3232e-02,
         9.2514e-03, -1.4164e-02, -9.4091e-03,  6.8163e-03, -3.8972e-02,
         1.7190e-02, -2.1136e-02,  8.4777e-03, -1.2556e-02,  6.4036e-03,
         7.4610e-03,  9.5836e-03, -9.9615e-03,  9.4567e-03,  2.5976e-02,
         2.5606e-02, -5.6751e-03,  2.3165e-02, -1.9443e-03, -1.0315e-03,
        -3.2236e-02, -2.5791e-03, -5.0094e-03, -5.8139e-03, -6.6907e-03,
        -1.2893e-02, -4.3120e-03,  6.8851e-03,  1.3315e-02, -3.9535e-02,
         1.3543e-02, -2.8399e-03,  1.1228e-02,  3.7920e-03, -2.1574e-03,
        -1.0655e-02,  6.1027e-03, -1.4650e-02, -4.6833e-03,  1.5268e-02,
        -1.0195e-02,  4.1723e-03,  3.5206e-02, -1.1