In [None]:
# Notebook configuration -- edit these before running the cells below.
# Checkpoint file handed to TransMorphModule.load_from_checkpoint.
MODEL_PATH = "../model_weights/test_model_capacity/y2ax6dmp/val_loss=-1.61053097&epoch=99.ckpt"
# Glob pattern; every match is treated as one series directory to register.
TARGET_PATH = "/media/agjvc_rad3/_TESTKOLLEKTIV/Daten/Daten/*/Series*"
# Name of the sub-directory created inside each series directory for the results.
OUTPUT_DIR = "images_reg_av_transmorph"

In [None]:
import gc
import glob
import os
import time
import sys

sys.path.append('../')

from pathlib import Path
import numpy as np
import pydicom
import scipy.io as spio
import shutil
import torch
import torchvision

In [None]:
from reg.data.utils import ZNormalization, RescaleIntensity, read_mat_data_w_meta
from reg.transmorph import TransMorphModule
from reg.transmorph.modules.spatial_transformer import SpatialTransformerSeries

In [None]:
# The processing loop moves every sample to the GPU, so fail fast without CUDA.
if not torch.cuda.is_available():
    raise RuntimeError("No cuda :(")

# Restore the trained module; strict=True rejects missing/unexpected weights.
model = TransMorphModule.load_from_checkpoint(str(MODEL_PATH), strict=True)
# Place the module on the GPU explicitly: the inputs are sent to CUDA later on,
# and the device a module ends up on after loading is not something this
# notebook controls otherwise.
model.cuda()
# Switch to inference mode so that any dropout / stochastic layers are disabled
# and normalisation layers use their stored statistics.
model.eval()

print(f"{'=' * 5} Configuration summary {'=' * 92}")
print()
print(model.hparams)
print()
print("=" * 120)

# Intensity preprocessing applied to each sample before it is fed to the model:
# z-normalisation followed by rescaling to the [0, 1] range.
transforms = torchvision.transforms.Compose([ZNormalization(), RescaleIntensity(0, 1)])

In [None]:
# glob returns its matches in arbitrary order; sort them so the processing
# order (and therefore the logs and per-series timings) is reproducible.
series_folders = sorted(glob.glob(TARGET_PATH))
series_paths = [Path(folder) for folder in series_folders]

In [None]:
def _report_cuda_memory(device=0):
    """Print allocated, reserved and peak-reserved CUDA memory of ``device``."""
    bytes_per_unit = 1024 ** 3
    for label, query in (
        ("memory_allocated", torch.cuda.memory_allocated),
        ("memory_reserved", torch.cuda.memory_reserved),
        ("max_memory_reserved", torch.cuda.max_memory_reserved),
    ):
        print("torch.cuda.%s: %fGB" % (label, query(device) / bytes_per_unit))


# Register every series: predict a flow field with the model, warp the original
# (un-normalised) data with it, and store the result next to the input as a
# .mat file plus a single template-based DICOM file.
# `deltas` collects the wall-clock seconds spent on each series.
deltas = []
for series_path_dir in series_paths:
    # perf_counter is monotonic, so durations are immune to system clock changes.
    start = time.perf_counter()

    reg_path_dir = series_path_dir / OUTPUT_DIR
    dicoms_mat_path = series_path_dir / "dicoms.mat"
    reg_path_dir.mkdir(parents=True, exist_ok=True)

    print(120 * "=")
    print(f"series_path_dir = {series_path_dir}")
    print(f"reg_path_dir = {reg_path_dir}")
    print(f"dicoms_mat_path = {dicoms_mat_path}")

    # `mat` is the loaded .mat content; it is written back out below with only
    # its "dcm" entry replaced by the registered data.
    data, mat = read_mat_data_w_meta(dicoms_mat_path)
    sample = torch.from_numpy(data.astype(np.float32)).unsqueeze(0).cuda()
    sample_transformed = transforms(sample)

    # Run both the flow prediction and the warp without autograd bookkeeping:
    # nothing is trained here, so building a graph would only cost memory.
    with torch.no_grad():
        _, flow, _ = model(sample_transformed)

        # The model sees the normalised sample, but the flow is applied to the
        # raw sample so the output keeps the original intensities.
        # NOTE(review): assumes sample.shape[2:] is the size the transformer
        # expects (everything after the first two axes) -- confirm.
        stn = SpatialTransformerSeries(sample.shape[2:]).cuda()
        warped = stn(sample, flow)

    # Drop the two leading axes, move the last axis to the front, copy to host.
    dcm = warped.view(warped.shape[2:]).permute(2, 0, 1).detach().cpu().numpy()
    mat["dcm"] = dcm

    # Release the GPU tensors before the file I/O below; `mat` keeps the result.
    del stn, sample, sample_transformed, warped, flow, dcm

    reg_result_path = reg_path_dir / "dicoms.mat"
    spio.savemat(reg_result_path, mat, long_field_names=True)

    # Reuse auxiliary files from the existing "images_reg_av" result directory.
    reference_dir = series_path_dir / "images_reg_av"
    shutil.copyfile(
        reference_dir / "dicomsNumber.mat",
        reg_path_dir / "dicomsNumber.mat",
    )
    dicom_path = reg_path_dir / "IM-0001-0001.dcm"

    # The copied DICOM file serves as a header template for the new pixel data.
    shutil.copyfile(reference_dir / "IM-0001-0001.dcm", dicom_path)
    ds = pydicom.dcmread(dicom_path)

    # Only the first slice along the leading axis is written to the DICOM file.
    data = mat["dcm"][0, :, :]
    assert data.dtype == np.float32, "DICOM pixel data should be float32"

    ds.Rows, ds.Columns = data.shape
    ds.BitsAllocated = 32
    # Keep the header self-consistent with the 32-bit allocation; otherwise the
    # template's BitsStored / HighBit would describe only part of each sample.
    ds.BitsStored = 32
    ds.HighBit = 31
    ds.SamplesPerPixel = 1
    ds.PhotometricInterpretation = "MONOCHROME2"
    ds.PixelRepresentation = 1  # For signed data

    # NOTE(review): raw float32 bytes are stored in PixelData, so readers must
    # reinterpret the 32-bit samples as floats. DICOM also defines a dedicated
    # Float Pixel Data element -- confirm what downstream tools expect.
    ds.PixelData = data.tobytes()
    ds.save_as(dicom_path)

    print("")
    print(f"reg_result_path = {reg_result_path}")
    print(f"dicom_path = {dicom_path}")
    print("")

    # Free cached GPU blocks so the numbers below reflect live allocations.
    gc.collect()
    torch.cuda.empty_cache()
    _report_cuda_memory(0)

    end = time.perf_counter()
    delta = end - start
    deltas.append(delta)
    print(f"delta = {delta:.4f}s")

In [None]:
# Timing summary over all processed series: mean wall-clock seconds per series
# and the total run time in minutes.
mean_seconds_per_series = np.mean(deltas)
total_minutes = np.sum(deltas) / 60
mean_seconds_per_series, total_minutes