In [None]:
import os
import sys
import shutil
import pydicom
import numpy as np
import scipy.io as spio
import pydicom
import glob
import torch
from pathlib2 import Path

import gc

In [None]:
from dataset import normalize, standardize, reader
from model import TransMorphModule
from utils import load_losses, load_model_params
from models.modules.spatial_transformer import SpatialTransformerSeries
from pytorch_lightning import Trainer

In [None]:
def _parse_ckpt_name(ckpt_file_name):
    """Extract the validation loss and epoch encoded in a checkpoint file name.

    The name is expected to look like ``<key>=<val_loss>&<key>=<epoch>.ckpt``.

    Args:
        ckpt_file_name: Bare file name (no directory part) of a checkpoint.

    Returns:
        ``(val_loss, epoch)`` as strings, or ``None`` when the name does not
        follow the pattern or the loss part is not a number.
    """
    fields = ckpt_file_name.split("&")
    if len(fields) != 2:
        return None
    try:
        val_loss = fields[0].split("=")[1]
        epoch = fields[1].split("=")[1].split(".")[0]
        float(val_loss)  # only checkpoints with a numeric loss can be ranked
    except (IndexError, ValueError):
        return None
    return val_loss, epoch


def load_best_model(ckpt_dir):
    """Load the checkpoint with the lowest validation loss from ``ckpt_dir``.

    The directory *name* encodes the training configuration as ``-``-separated
    fields, either 9 fields::

        model_name-image_loss-flow_loss-optimizer-lr-target_type-max_epoch-series_len-data_mod

    or 10 fields, where the second one is a model version that gets appended
    to the model name. Checkpoint *file* names encode the validation loss and
    the epoch (see ``_parse_ckpt_name``).

    Args:
        ckpt_dir: Path (str or Path) of the directory holding the ``*.ckpt`` files.

    Returns:
        Tuple ``(model, data_mod)``: the ``TransMorphModule`` restored from the
        best checkpoint and the data-modification tag taken from the last
        field of the directory name.

    Raises:
        FileNotFoundError: If the directory contains no ``*.ckpt`` file.
        ValueError: If no checkpoint name carries a numeric validation loss,
            or the directory name has neither 9 nor 10 fields.
    """
    ckpt_dir = Path(ckpt_dir)
    # Reverse-sorted so that, among equal losses, the pick is deterministic
    # (min() keeps the first minimal element it encounters).
    ckpt_file_names = sorted([c.name for c in ckpt_dir.glob("*.ckpt")], reverse=True)
    if not ckpt_file_names:
        raise FileNotFoundError(f"No *.ckpt files found in '{ckpt_dir}'")

    candidates = []
    for ckpt_file_name in ckpt_file_names:
        parsed = _parse_ckpt_name(ckpt_file_name)
        if parsed is not None:
            candidates.append((ckpt_file_name, parsed))
    if not candidates:
        raise ValueError(
            f"No checkpoint in '{ckpt_dir}' is named like "
            f"'<key>=<val_loss>&<key>=<epoch>.ckpt': {ckpt_file_names}"
        )

    # Rank by the numeric loss; comparing the file names as strings would
    # order e.g. "0.9" before "0.1" and mis-rank values of different widths.
    best_name, (val_loss, epoch) = min(candidates, key=lambda item: float(item[1][0]))
    best_ckpt = ckpt_dir / best_name

    ident = str(ckpt_dir.name).split("-")
    if len(ident) == 9:
        (model_name, image_loss, flow_loss, _optimizer_name, _lr,
         target_type, max_epoch, series_len, data_mod) = ident
    elif len(ident) == 10:
        (model_name, model_ver, image_loss, flow_loss, _optimizer_name, _lr,
         target_type, max_epoch, series_len, data_mod) = ident
        model_name = f"{model_name}-{model_ver}"
    else:
        raise ValueError(
            f"Cannot parse checkpoint directory name '{ckpt_dir.name}': "
            f"expected 9 or 10 '-'-separated fields, got {len(ident)}"
        )

    target_type = str.lower(target_type)
    series_len = int(series_len)
    max_epoch = int(max_epoch)

    image_losses, flow_losses = load_losses(image_loss, flow_loss, delimiter="=")

    # No optimizer is requested (optimizer_name=None), so the optimizer slot
    # of the returned tuple is ignored.
    net, criteria_image, criteria_flow, criterion_disp, _optimizer = load_model_params(
        model_name=model_name,
        image_losses=image_losses,
        flow_losses=flow_losses,
        optimizer_name=None,
        series_len=series_len)

    model = TransMorphModule.load_from_checkpoint(
        str(best_ckpt),
        strict=False,
        net=net,
        criteria_image=criteria_image,
        criteria_flow=criteria_flow,
        criterion_disp=criterion_disp,
        target_type=target_type,
    )

    print("=" * 80)
    print("Best performing model")
    print("")
    print(f"    val_loss    : {val_loss}")
    print(f"    epoch       : {epoch}")
    print("")
    print(f"    model_name  : {model_name}")
    print(f"    image_loss  : {criteria_image}")
    print(f"    flow_loss   : {criteria_flow}")
    print(f"    disp_loss   : {criterion_disp}")
    print(f"    target_type : {target_type}")
    print(f"    max_epoch   : {max_epoch}")
    print(f"    series_len  : {series_len}")
    print(f"    data_mod    : {data_mod}")
    print("")
    print("=" * 80)

    return model, data_mod

In [None]:
# Checkpoint directory; load_best_model() parses the training configuration
# (model, losses, series length, data modification, ...) out of its name.
cpkt = "model_weights_v2/transmorph-gmi=1-gl2d=1-adam-0.0001-last-100-32-norm/"

# The inference cells move tensors to the GPU with .cuda(), so stop right
# away when no CUDA device is available. RuntimeError is still caught by any
# `except Exception` handler.
if not torch.cuda.is_available():
    raise RuntimeError("No cuda :(")

In [None]:
# glob.glob() returns matches in arbitrary, filesystem-dependent order. Sort
# them so the series are always visited in the same order between runs.
series_folders = sorted(glob.glob("/media/agjvc_rad3/_TESTKOLLEKTIV/Daten/Daten/*/Series*"))
series_paths = [Path(p) for p in series_folders]

# Restore the best checkpoint; data_mod names the intensity pre-processing
# ("std" / "norm") that is applied to the input before inference.
model, data_mod = load_best_model(cpkt)

In [None]:
# Drop unreachable Python objects first, then release PyTorch's cached CUDA
# blocks, so the figures below describe what is still held on device 0.
gc.collect()
torch.cuda.empty_cache()

# One report line per statistic, converted from bytes to GB.
for report_fmt, query_bytes in (
    ("torch.cuda.memory_allocated: %fGB", torch.cuda.memory_allocated),
    ("torch.cuda.memory_reserved: %fGB", torch.cuda.memory_reserved),
    ("torch.cuda.max_memory_reserved: %fGB", torch.cuda.max_memory_reserved),
):
    print(report_fmt % (query_bytes(0) / 1024 / 1024 / 1024))

In [None]:
# Inference only: put the module in eval mode so layers such as dropout
# behave deterministically (a freshly loaded module is in training mode).
model.eval()

for series_path_dir in series_paths:
    print(80 * "=")
    print(f"series_path_dir = {series_path_dir}")

    # Output folder for the registered series, created inside the series folder.
    reg_path_dir = series_path_dir / "images_reg_av_transmorph"
    print(f"reg_path_dir = {reg_path_dir}")

    if not os.path.exists(reg_path_dir):
        os.makedirs(reg_path_dir)
    else:
        print("  -> reg_path_dir already exists")

    dicoms_mat_path = series_path_dir / "dicoms.mat"
    print(f"dicoms_mat_path = {dicoms_mat_path}")

    # With mat=True the reader yields the image array plus the loaded .mat content.
    data, mat = reader(dicoms_mat_path, mat=True)

    # Apply the same intensity pre-processing the model was trained with.
    if data_mod == "std":
        ndat = standardize(data)
    elif data_mod == "norm":
        ndat = normalize(data)
    else:
        ndat = data

    # Predict the displacement fields from the pre-processed series.
    sample = torch.from_numpy(ndat.astype(np.float32)).unsqueeze(0).cuda()
    with torch.no_grad():
        _, flows, _ = model(sample)

    # Free GPU/host memory as early as possible; the series can be large.
    del ndat
    del sample

    # Warp the *unmodified* intensities so the result stays in the original
    # value range of the input data.
    sample = torch.from_numpy(data.astype(np.float32)).unsqueeze(0).cuda()
    stn = SpatialTransformerSeries(sample.shape[2:]).cuda()
    warped = stn(sample, flows)

    del stn
    del sample
    del flows

    # Drop the two leading singleton dims and move the last axis to the front.
    # NOTE(review): presumably (H, W, T) -> (T, H, W) -- confirm against reader().
    dcm = warped.view(warped.shape[2:]).permute(2,0,1).detach().cpu().numpy()
    del warped

    mat["dcm"] = dcm
    del dcm

    reg_result_path = reg_path_dir / "dicoms.mat"
    print(f"reg_result_path = {reg_result_path}")
    spio.savemat(reg_result_path , mat, long_field_names=True)

    shutil.copyfile(series_path_dir / 'images_reg_av/dicomsNumber.mat', reg_path_dir / 'dicomsNumber.mat')

    dicom_path = reg_path_dir / "IM-0001-0001.dcm"
    print(f"dicom_path = {dicom_path}")

    # Reuse an existing DICOM file as header template and replace its pixels
    # with the first warped frame.
    shutil.copyfile(series_path_dir / "images_reg_av/IM-0001-0001.dcm", dicom_path)
    ds = pydicom.dcmread(dicom_path)

    # PixelData must hold integers of the width/signedness declared in the
    # header, while the warped frame is float32. Round, clip to the valid
    # range and cast before serialising; otherwise the byte length and the
    # encoding would not match the header.
    # NOTE(review): assumes the frame shape matches ds.Rows x ds.Columns -- confirm.
    pixel_dtype = np.dtype(("int" if ds.PixelRepresentation else "uint") + str(ds.BitsAllocated))
    pixel_range = np.iinfo(pixel_dtype)
    data = np.clip(np.rint(mat['dcm'][0,:,:]), pixel_range.min, pixel_range.max).astype(pixel_dtype)
    ds.PixelData = data.tobytes()
    ds.save_as(dicom_path)

    print()

    gc.collect()
    torch.cuda.empty_cache()
    print("torch.cuda.memory_allocated: %fGB"%(torch.cuda.memory_allocated(0)/1024/1024/1024))
    print("torch.cuda.memory_reserved: %fGB"%(torch.cuda.memory_reserved(0)/1024/1024/1024))
    print("torch.cuda.max_memory_reserved: %fGB"%(torch.cuda.max_memory_reserved(0)/1024/1024/1024))
    # NOTE(review): only the first series is processed because of this break;
    # looks like a debugging leftover -- remove it to register every series.
    break
    