In [None]:
# import necessary libraries

import os
import yaml
import torch
import random
import numpy as np
import torchio as tio
from models.networks import *
from easydict import EasyDict
from torch.utils.data import DataLoader

In [None]:
# necessary utility functions

def get_data(data_path):
    """Wrap one channel-stacked ``.npy`` volume as a single-subject torchio dataset.

    Channel 0 of the array becomes the ``t1`` image and channel 1 the ``t2``
    image (each kept 4-D with a leading channel axis of size 1). Any further
    channels are not used here.

    Args:
        data_path: path (str or path-like) of the ``.npy`` file to load.

    Returns:
        A ``tio.SubjectsDataset`` containing exactly one subject.
    """
    volume = np.load(os.fspath(data_path))
    # slices (not integer indexing) keep the channel axis that ScalarImage expects
    subject = tio.Subject(
        t1=tio.ScalarImage(tensor=volume[0:1]),
        t2=tio.ScalarImage(tensor=volume[1:2]),
    )
    return tio.SubjectsDataset([subject])

In [None]:
# load the trained model: config, checkpoint, and the three frozen sub-networks

with open('core/config.yaml') as f:
    config = EasyDict(yaml.safe_load(f))

# NOTE: torch.load unpickles the file, so only load checkpoints from a trusted source.
# map_location='cpu' lets the checkpoint load even if it was saved on a GPU index
# that does not exist on this machine; load_state_dict below copies each tensor
# onto the device of the parameter it fills, so no second GPU copy is kept.
state_dict = torch.load('results/checkpoint.pth', map_location='cpu')


def load_frozen(module, key):
    """Wrap ``module`` in DataParallel, load its trained weights, and freeze it.

    Args:
        module: freshly constructed network (``Encoder`` / ``Decoder``).
        key: entry of the checkpoint dict holding this network's weights.

    Returns:
        The ``DataParallel``-wrapped module on GPU with weights from
        ``state_dict[key]``, gradients disabled, and in eval mode.
    """
    # Weights are loaded into the DataParallel wrapper, so the checkpoint keys
    # presumably carry the 'module.' prefix — confirm against the training script.
    # NOTE(review): DataParallel expects the module on device_ids[0]; .cuda() uses the
    # current CUDA device, so config.misc.devices is assumed to start at that device.
    wrapped = torch.nn.DataParallel(module.cuda(), device_ids=config.misc.devices)
    wrapped.load_state_dict(state_dict[key])
    wrapped.requires_grad_(False)
    return wrapped.eval()


# style encoder, content encoder, and decoder used by the inference cell
enc_s = load_frozen(Encoder(inc=config.dataset.nmodal, zdim=config.model.latent_dim), 'enc_s')
enc_c = load_frozen(Encoder(inc=config.dataset.nmodal, zdim=config.model.latent_dim), 'enc_c')
dec = load_frozen(Decoder(outc=config.dataset.nmodal, zdim=config.model.latent_dim), 'dec')

In [None]:
# run inference: re-render every source scan with the style of one reference
# scan drawn from the destination directory

config.dataset.src_dir = 'data/processed/6m'
config.dataset.dst_dir = 'data/processed/12m'
REF_SEED = 0  # seed for choosing the reference (style) scan, for reproducible runs


def list_volumes(directory):
    """Return the sorted ``.npy`` file names in ``directory``.

    Sorting makes the processing order and the seeded reference choice
    independent of ``os.listdir``'s arbitrary order. Filtering skips stray
    entries (hidden files, sub-directories) that ``np.load`` cannot read.
    """
    return sorted(name for name in os.listdir(directory) if name.endswith('.npy'))


def transfer_style(src_subject, dst_subject):
    """Patch-wise synthesis: content of ``src_subject`` with the style of ``dst_subject``.

    Both subjects must have ``t1`` and ``t2`` images of the same spatial shape,
    so that their patch grids line up one-to-one.

    Returns:
        The aggregated synthetic volume as a numpy array (the output tensor of
        ``GridAggregator``, moved to CPU).
    """
    sampler_src = tio.inference.GridSampler(src_subject, config.model.patch_size, config.test.patch_overlap)
    sampler_dst = tio.inference.GridSampler(dst_subject, config.model.patch_size, config.test.patch_overlap)
    # zip() below would silently drop patches if the two grids differed
    assert len(sampler_src) == len(sampler_dst), 'source and reference patch grids differ'
    loader_src = DataLoader(sampler_src, config.test.batch_size)
    loader_dst = DataLoader(sampler_dst, config.test.batch_size)
    aggregator = tio.inference.GridAggregator(sampler_src, 'average')
    for patch_src, patch_dst in zip(loader_src, loader_dst):
        x = torch.cat([patch_src['t1'][tio.DATA], patch_src['t2'][tio.DATA]], 1).cuda()
        y = torch.cat([patch_dst['t1'][tio.DATA], patch_dst['t2'][tio.DATA]], 1).cuda()
        with torch.no_grad(), torch.cuda.amp.autocast():
            s_y = enc_s(y)  # style code of the reference patch
            # content code; the second argument (True) makes enc_c also return
            # emb_x, which the decoder consumes — confirm its meaning in models.networks
            c_x, emb_x = enc_c(x, True)
            syn_y = dec(c_x + s_y, emb_x)  # transfer
        # under autocast syn_y may be float16; accumulate overlapping patches in float32
        aggregator.add_batch(syn_y.float(), patch_src[tio.LOCATION])
    return aggregator.get_output_tensor().cpu().numpy()


src_names = list_volumes(config.dataset.src_dir)
dst_names = list_volumes(config.dataset.dst_dir)
dst_name = random.Random(REF_SEED).choice(dst_names)

# the reference scan is the same for every source scan, so load it only once
dst_full = get_data(os.path.join(config.dataset.dst_dir, dst_name))[0]

# normpath so a trailing slash on src_dir does not yield an empty basename
out_dir = os.path.join(config.test.out_dir, os.path.basename(os.path.normpath(config.dataset.src_dir)))
os.makedirs(out_dir, exist_ok=True)

for src_name in src_names:
    src_path = os.path.join(config.dataset.src_dir, src_name)
    src_subject = get_data(src_path)[0]
    # crop/pad the reference to the source's spatial shape so both patch grids match
    dst_subject = tio.CropOrPad(src_subject.shape[1:])(dst_full)
    # syn is src content with dst style
    syn = transfer_style(src_subject, dst_subject)
    # mask with, then append, the segmentation stored as the source file's last channel
    seg = np.load(src_path)[-1:]
    syn = np.concatenate([syn * (seg > 0), seg])
    np.save(os.path.join(out_dir, src_name), syn)