In [None]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam
from collections import OrderedDict
import random

class Base_Model(nn.Module):
    """Common plumbing for models: device selection, checkpoint directories and
    saving/loading of the sub-networks listed in ``self.model_names``.

    Sub-networks are looked up as attributes named ``'net' + name`` (e.g. ``netG``).
    """

    def __init__(self):
        super(Base_Model, self).__init__()

    @property
    def name(self):
        """Identifier used as a path component of the checkpoint directory."""
        return 'BaseModel'

    def init_network(self, opt):
        """Configure device, checkpoint directory and bookkeeping lists.

        Args:
            opt: options object; this method reads ``gpu_ids``, ``MODEL_SAVE_PATH``,
                ``data`` and ``tag`` from it.
        """
        self.opt = opt
        self.gpu_ids = opt.gpu_ids
        # Use the first listed GPU only when CUDA is actually available.
        self.device = torch.device('cuda:{}'.format(self.gpu_ids[0])) if (self.gpu_ids and torch.cuda.is_available()) else torch.device('cpu')
        self.save_root = os.path.join(opt.MODEL_SAVE_PATH, self.name, opt.data, opt.tag)
        self.save_dir = os.path.join(self.save_root, 'checkpoint')
        os.makedirs(self.save_dir, exist_ok=True)

        self.loss_names = []
        self.metrics_names = []
        self.model_names = []
        self.visual_names = []
        self.image_paths = []
        self.optimizers = []

    def init_loss(self, opt):
        """Create the GAN and L1 reconstruction criteria on ``self.device``.

        Raises:
            ImportError: if ``GANLoss`` / ``RestructionLoss`` are not defined or
                cannot be imported.
        """
        try:
            self.criterionGAN = GANLoss(use_lsgan=True).to(self.device)
            self.criterionL1 = RestructionLoss(distance='l1').to(self.device)
        except (NameError, ImportError) as e:
            # An undefined loss class surfaces as NameError, not ImportError,
            # so both are translated into the documented ImportError.
            raise ImportError(f"Error initializing loss functions: {e}. Ensure required loss classes are defined.") from e

    def _checkpoint_dir(self, which_epoch, latest):
        """Directory of a checkpoint: ``save_dir`` for 'latest', else ``save_dir/<epoch>``."""
        return self.save_dir if latest else os.path.join(self.save_dir, str(which_epoch))

    @staticmethod
    def _checkpoint_filename(which_epoch, name, latest):
        """Checkpoint file name ``'<epoch|latest>_net_<name>.pth'``."""
        return '{}_net_{}.pth'.format('latest' if latest else which_epoch, name)

    def save_networks(self, which_epoch, total_steps=0, latest=False):
        """Save every ``net<name>`` in ``self.model_names`` to a ``.pth`` file.

        Each file stores ``{'iters', 'epoch', 'state_dict', 'lr'}``. ``lr`` is
        the learning rate of the first optimizer, or ``None`` when there is no
        optimizer. The networks themselves are left on their current device.

        Args:
            which_epoch: epoch label used in the directory and file name.
            total_steps: iteration count stored as ``'iters'``.
            latest: save as ``latest_net_<name>.pth`` directly in ``save_dir``.
        """
        for name in self.model_names:
            if not isinstance(name, str):
                continue
            save_dir = self._checkpoint_dir(which_epoch, latest)
            os.makedirs(save_dir, exist_ok=True)
            save_path = os.path.join(save_dir, self._checkpoint_filename(which_epoch, name, latest))
            net = getattr(self, 'net' + name)
            # Copy tensors to CPU without calling net.cpu(), which would move the
            # live network off its training device as a side effect. Mutating the
            # fresh state_dict in place keeps its version metadata.
            param_dict = net.state_dict()
            for key in param_dict:
                param_dict[key] = param_dict[key].cpu()
            lr = self.optimizers[0].param_groups[0]['lr'] if self.optimizers else None
            torch.save({'iters': total_steps, 'epoch': which_epoch, 'state_dict': param_dict, 'lr': lr}, save_path)

    def load_networks(self, which_epoch, load_path=None, latest=False):
        """Load ``net<name>`` weights for every name in ``self.model_names``.

        Args:
            which_epoch: epoch label used in the file name (ignored when ``latest``).
            load_path: directory to read from. Defaults to the directory that
                ``save_networks`` writes for the same ``which_epoch``/``latest``,
                falling back to ``save_dir`` itself for flat layouts.
            latest: load ``latest_net_<name>.pth`` instead of an epoch file.
        """
        for name in self.model_names:
            if not isinstance(name, str):
                continue
            load_filename = self._checkpoint_filename(which_epoch, name, latest)
            if load_path:
                full_path = os.path.join(load_path, load_filename)
            else:
                full_path = os.path.join(self._checkpoint_dir(which_epoch, latest), load_filename)
                # Backward compatibility: earlier code looked directly in save_dir.
                if not os.path.exists(full_path):
                    full_path = os.path.join(self.save_dir, load_filename)
            print(f"Loading model from {full_path}")
            save_dict = torch.load(full_path, map_location=self.device)
            getattr(self, 'net' + name).load_state_dict(save_dict['state_dict'])


In [None]:
import torch.nn as nn
import functools

class NLayer_3D_Discriminator(nn.Module):
    """Stack of strided ``Conv3d`` blocks producing a per-patch prediction map.

    Args:
        input_nc: number of input channels.
        ndf: number of filters in the first conv layer.
        n_layers: number of stride-2 convs (including the first one).
        norm_layer: normalization layer class, or a ``functools.partial`` of one.
        use_sigmoid: append a ``Sigmoid`` after the final conv.
        n_out_channels: channels of the output prediction map.
    """

    def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm3d, use_sigmoid=False, n_out_channels=1):
        super(NLayer_3D_Discriminator, self).__init__()
        self.n_layers = n_layers
        base_norm = norm_layer.func if isinstance(norm_layer, functools.partial) else norm_layer
        # Conv bias is only kept when the norm is InstanceNorm3d.
        use_bias = base_norm == nn.InstanceNorm3d

        kw = 4
        # ceil((kw - 1) / 2) == kw // 2 for integer kw; avoids the undefined `np`.
        padw = kw // 2

        # Layers are appended in the same order as before, so Sequential indices
        # (and therefore state_dict keys) are unchanged.
        layers = [
            nn.Conv3d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw),
            nn.LeakyReLU(0.2, True),
        ]

        nf_mult = 1
        for n in range(1, n_layers):
            nf_mult_prev = nf_mult
            nf_mult = min(2 ** n, 8)
            layers += [
                nn.Conv3d(ndf * nf_mult_prev, ndf * nf_mult,
                          kernel_size=kw, stride=2, padding=padw, bias=use_bias),
                norm_layer(ndf * nf_mult),
                nn.LeakyReLU(0.2, True),
            ]

        # One more widening conv at stride 1 (no further downsampling).
        nf_mult_prev = nf_mult
        nf_mult = min(2 ** n_layers, 8)
        layers += [
            nn.Conv3d(ndf * nf_mult_prev, ndf * nf_mult,
                      kernel_size=kw, stride=1, padding=padw, bias=use_bias),
            norm_layer(ndf * nf_mult),
            nn.LeakyReLU(0.2, True),
        ]

        # Output head.
        layers.append(nn.Conv3d(ndf * nf_mult, n_out_channels, kernel_size=kw, stride=1, padding=padw))
        if use_sigmoid:
            layers.append(nn.Sigmoid())

        self.model = nn.Sequential(*layers)

    def forward(self, input):
        """Return the prediction map for a 5D ``input`` tensor."""
        return self.model(input)
import torch
import torch.nn as nn

class UNetLike_DownStep5(nn.Module):
    """Small 2D conv encoder/decoder that keeps the spatial size unchanged.

    Args:
        input_shape: accepted for interface compatibility; not used.
        encoder_input_channels: channels of the input image.
        decoder_output_channels: channels of the output map.
        decoder_out_activation: activation *class*, instantiated with no
            arguments after the last conv (e.g. ``nn.ReLU`` or ``nn.LeakyReLU``).
    """

    def __init__(self, input_shape, encoder_input_channels, decoder_output_channels, decoder_out_activation):
        super().__init__()

        # Encoder: two 3x3 convs (padding 1 keeps H and W), widening to 128 channels.
        encoder_layers = [
            nn.Conv2d(encoder_input_channels, 64, kernel_size=3, padding=1),
            nn.ReLU(True),
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.ReLU(True),
        ]
        self.encoder = nn.Sequential(*encoder_layers)

        # Decoder: a single 3x3 conv down to the requested channels, then the activation.
        head = nn.Conv2d(128, decoder_output_channels, kernel_size=3, padding=1)
        self.decoder = nn.Sequential(head, decoder_out_activation())

    def forward(self, x):
        """Encode then decode ``x``; output has ``decoder_output_channels`` channels."""
        return self.decoder(self.encoder(x))
class MultiView_UNetLike_DenseDimensionNet(nn.Module):
    """Fuse two 2D views into a 5D tensor and refine it with a 3D conv.

    Each view is processed by its own 2D model. Each output gets a depth axis
    of size 1 (``unsqueeze(2)``). The two are then aligned, permuted and
    averaged by ``Transposed_And_Add``, and passed through
    ``Conv3d(kernel_size=7) + activation``.

    Args:
        view1Model, view2Model: 2D modules applied to the respective view.
        view1Order, view2Order: permutations applied to each 5D view tensor before fusion.
        backToSub: stored but not used by this class.
        decoder_output_channels: channel count of the fused tensor and of the output.
        decoder_out_activation: activation class instantiated after the 3D conv.
        decoder_norm_layer: accepted for interface compatibility; not used.
    """

    def __init__(self, view1Model, view2Model, view1Order, view2Order, backToSub,
                 decoder_output_channels, decoder_out_activation, decoder_norm_layer=nn.BatchNorm3d):
        super(MultiView_UNetLike_DenseDimensionNet, self).__init__()
        self.view1Model = view1Model
        self.view2Model = view2Model
        self.view1Order = view1Order
        self.view2Order = view2Order
        self.backToSub = backToSub

        self.transposed_layer = Transposed_And_Add(view1Order, view2Order)
        # The fused tensor must already have `decoder_output_channels` channels,
        # i.e. the view models must output that many channels.
        self.decoder_layer = nn.Sequential(
            nn.Conv3d(decoder_output_channels, decoder_output_channels, kernel_size=7, padding=3, bias=False),
            decoder_out_activation()
        )

    def forward(self, inputs):
        """Run both views, fuse them and return the refined 5D tensor.

        Args:
            inputs: sequence ``[view1_input, view2_input]``; each is a 4D tensor
                ``[batch, channels, height, width]``.

        Returns:
            Output of the 3D decoder layer applied to the fused tensor.

        Raises:
            ValueError: if either view input is not 4D.
        """
        view1_input, view2_input = inputs[0], inputs[1]
        for view in (view1_input, view2_input):
            if view.dim() != 4:
                raise ValueError(
                    f"Expected 4D view inputs [batch, channels, height, width], got shape {tuple(view.shape)}")

        # The 2D models consume 4D tensors; unsqueeze(2) adds a depth axis of size 1.
        view1_output = self.view1Model(view1_input).unsqueeze(2)
        view2_output = self.view2Model(view2_input).unsqueeze(2)

        fused_output = self.transposed_layer(view1_output, view2_output)

        # Only the refined fused tensor is returned (not the per-view outputs).
        return self.decoder_layer(fused_output)




class Transposed_And_Add(nn.Module):
    """Fuse two 5D view tensors by resizing, permuting and averaging.

    When the trailing three dimensions of ``view2`` differ from ``view1``'s,
    ``view2`` is trilinearly resized to match. Each tensor is then permuted by
    its order tuple, and the two are averaged element-wise.

    Args:
        view1Order, view2Order: dimension permutations passed to ``Tensor.permute``.
    """

    def __init__(self, view1Order, view2Order):
        super(Transposed_And_Add, self).__init__()
        self.view1Order = view1Order
        self.view2Order = view2Order

    def forward(self, view1, view2):
        """Return ``(view1.permute(view1Order) + resized_view2.permute(view2Order)) / 2``.

        Args:
            view1: 5D tensor, presumably ``[batch, channels, depth, height, width]``.
            view2: 5D tensor with the same batch and channel sizes as ``view1``.
        """
        # A depth mismatch and a height/width mismatch both required the same resize,
        # so one comparison of the three trailing dims covers both cases.
        if view1.shape[2:] != view2.shape[2:]:
            view2 = nn.functional.interpolate(
                view2,
                size=tuple(view1.shape[2:]),
                mode='trilinear',
                align_corners=False
            )

        view1_transposed = view1.permute(*self.view1Order)
        view2_transposed = view2.permute(*self.view2Order)

        # Element-wise average of the two aligned views.
        return (view1_transposed + view2_transposed) / 2


# Smoke test: build the two-view fusion model and check the fused output shape.
# Each view gets its own 2D model (1 input channel -> 64 output channels).
torch.manual_seed(0)  # reproducible random weights and inputs
view1Model = UNetLike_DownStep5(input_shape=(320, 320), encoder_input_channels=1, decoder_output_channels=64, decoder_out_activation=nn.ReLU)
view2Model = UNetLike_DownStep5(input_shape=(320, 320), encoder_input_channels=1, decoder_output_channels=64, decoder_out_activation=nn.ReLU)

# Fusion model; both views keep the identity dimension order so they align directly.
multi_view_model = MultiView_UNetLike_DenseDimensionNet(
    view1Model=view1Model,
    view2Model=view2Model,
    view1Order=(0, 1, 2, 3, 4),
    view2Order=(0, 1, 2, 3, 4),
    backToSub=True,
    decoder_output_channels=64,
    decoder_out_activation=nn.ReLU
)

# Random input per view: batch 1, channel 1, height 320, width 320.
input_data = [
    torch.randn(1, 1, 320, 320),  # view1_input
    torch.randn(1, 1, 320, 320)   # view2_input
]

# Inference only: no_grad avoids keeping activations for a backward pass.
with torch.no_grad():
    fused_output = multi_view_model(input_data)

print(f"fused_output shape: {fused_output.shape}")
import torch
import torch.nn as nn

# Generator definition
class Generator3D(nn.Module):
    """Generator mapping a pair of 2D views to a fused 5D volume.

    Wraps two ``UNetLike_DownStep5`` view models in a
    ``MultiView_UNetLike_DenseDimensionNet`` that uses identity view orders.

    Args:
        input_nc: accepted for interface compatibility; not used.
        output_nc: channels of the fused 3D output.
        encoder_input_shape: forwarded as ``input_shape`` to both view models.
        encoder_input_channels: channels of each view input.
        decoder_output_channels: output channels of each 2D view model.
        decoder_out_activation: activation class used by the view models and the 3D head.

    Raises:
        ValueError: if ``output_nc != decoder_output_channels``. The 3D head's
            conv expects ``output_nc`` input channels but receives the view
            models' ``decoder_output_channels``.
    """

    def __init__(self, input_nc, output_nc, encoder_input_shape, encoder_input_channels, decoder_output_channels, decoder_out_activation):
        super(Generator3D, self).__init__()
        if output_nc != decoder_output_channels:
            raise ValueError(
                f"output_nc ({output_nc}) must equal decoder_output_channels ({decoder_output_channels}): "
                "the fused view outputs feed a Conv3d with output_nc input channels.")

        # Despite their names, these are the view-1 and view-2 2D models.
        # The attribute names are kept so existing state_dict keys still match.
        self.encoder_model = UNetLike_DownStep5(
            input_shape=encoder_input_shape,
            encoder_input_channels=encoder_input_channels,
            decoder_output_channels=decoder_output_channels,
            decoder_out_activation=decoder_out_activation
        )

        self.decoder_model = UNetLike_DownStep5(
            input_shape=encoder_input_shape,
            encoder_input_channels=encoder_input_channels,
            decoder_output_channels=decoder_output_channels,
            decoder_out_activation=decoder_out_activation
        )

        # Identity orders: the 5D tensors stay [batch, channel, depth, height, width].
        self.generator_model = MultiView_UNetLike_DenseDimensionNet(
            view1Model=self.encoder_model,
            view2Model=self.decoder_model,
            view1Order=(0, 1, 2, 3, 4),
            view2Order=(0, 1, 2, 3, 4),
            backToSub=True,
            decoder_output_channels=output_nc,
            decoder_out_activation=decoder_out_activation
        )

    def forward(self, x):
        """Generate a volume from ``x = [view1, view2]``.

        Args:
            x: list or tuple of two tensors with identical shapes.

        Raises:
            ValueError: if ``x`` is not a 2-element list/tuple or the shapes differ.
        """
        if not isinstance(x, (list, tuple)) or len(x) != 2:
            raise ValueError("Input must be a list containing two tensors: [view1, view2].")
        if x[0].shape != x[1].shape:
            raise ValueError("The shapes of view1 and view2 must match.")
        return self.generator_model(x)
import torch
import torch.nn as nn

# Discriminator definition
class Discriminator3D(nn.Module):
    """Thin wrapper exposing an ``NLayer_3D_Discriminator`` as ``discriminator_model``.

    Args:
        input_nc: channels of the discriminator input.
        ndf: filters in the first conv layer.
        n_layers: number of stride-2 conv layers.
        norm_layer: normalization layer class.
        use_sigmoid: whether the wrapped network ends with a Sigmoid.
    """

    def __init__(self, input_nc, ndf, n_layers, norm_layer=nn.BatchNorm3d, use_sigmoid=False):
        super().__init__()
        settings = dict(
            input_nc=input_nc,
            ndf=ndf,
            n_layers=n_layers,
            norm_layer=norm_layer,
            use_sigmoid=use_sigmoid,
        )
        self.discriminator_model = NLayer_3D_Discriminator(**settings)

    def forward(self, x):
        """Return the wrapped discriminator's prediction map for ``x``."""
        prediction = self.discriminator_model(x)
        return prediction
from easydict import EasyDict

# Configuration
opt = EasyDict(
    # Adam optimizer settings (used for both G and D in CTGAN)
    lr=0.0002,
    beta1=0.5,
    beta2=0.999,
    # size passed to ImagePool for fake samples
    pool_size=50,
    # weight of the L1 reconstruction term in CTGAN.backward_G
    idt_lambda=0.5,
    # not referenced in this section
    gan_lambda=1.0,
    # generator: input_nc is accepted but unused by Generator3D
    input_nc_G=2,
    output_nc_G=1,
    encoder_input_shape=(320, 320),
    encoder_input_nc=1,
    # discriminator
    ndf=16,
    n_layers_D=3,
    input_nc_D=3,  # merged 3-channel input: two X-rays + one CT volume
    # (mean, std) pairs; presumably for normalization - not used in this section
    CT_MEAN_STD=[0.0, 1.0],
    XRAY1_MEAN_STD=[0.0, 1.0],
    XRAY2_MEAN_STD=[0.0, 1.0],
)
class CTGAN(nn.Module):
    """Conditional GAN that reconstructs a CT volume from two X-ray views.

    The generator maps ``[front_xray, side_xray]`` to a volume. The
    discriminator scores the channel-wise concatenation of both X-rays
    (resized to the CT volume's spatial size) with a real or generated volume.
    The same conditioned layout is used for D's real branch, D's fake branch
    and G's adversarial loss.
    """

    def __init__(self, opt, generator, discriminator):
        super(CTGAN, self).__init__()
        self.opt = opt
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # Generator and discriminator, moved to the training device
        self.netG = generator.to(self.device)
        self.netD = discriminator.to(self.device)

        # Loss functions (GANLoss / RestructionLoss are not defined in this section)
        self.criterionGAN = GANLoss(use_lsgan=True).to(self.device)
        self.criterionL1 = RestructionLoss(distance='l1').to(self.device)

        # Optimizers
        self.optimizer_G = Adam(self.netG.parameters(), lr=opt.lr, betas=(opt.beta1, opt.beta2))
        self.optimizer_D = Adam(self.netD.parameters(), lr=opt.lr, betas=(opt.beta1, opt.beta2))

        # Fake-sample pool queried before D's fake branch
        self.fake_pool = ImagePool(opt.pool_size)

        # Tensor shapes recorded during the backward passes, for debugging
        self.debug_info = {}

    def set_input(self, input_data):
        """Move the batch's front/side X-rays and CT volume to ``self.device``."""
        self.xray1 = input_data['front_xray_tensor'].to(self.device)
        self.xray2 = input_data['side_xray_tensor'].to(self.device)
        self.ct_volume = input_data['ct_volume_tensor'].to(self.device)

    def forward(self):
        """Generate ``self.fake_ct`` from the two X-ray views."""
        self.fake_ct = self.netG([self.xray1, self.xray2])

    @staticmethod
    def _to_5d(volume):
        """Return ``volume`` as 5D, inserting a depth axis at dim 2 for 4D input.

        Raises:
            ValueError: for tensors that are neither 4D nor 5D.
        """
        if volume.dim() == 5:
            return volume
        if volume.dim() == 4:
            return volume.unsqueeze(2)
        raise ValueError(f"Unexpected fake_ct shape: {volume.shape}")

    def _resize_to_ct(self, volume):
        """Trilinearly resize a 4D/5D tensor to the spatial size of ``self.ct_volume``."""
        return torch.nn.functional.interpolate(
            self._to_5d(volume), size=self.ct_volume.shape[2:], mode='trilinear', align_corners=False
        )

    def backward_D(self):
        """Compute the discriminator loss on real vs. pooled fake inputs and backprop it."""
        xray1_resized = self._resize_to_ct(self.xray1)
        xray2_resized = self._resize_to_ct(self.xray2)
        # Resize to the real CT's size (previously a hard-coded 320^3, which broke
        # torch.cat for any other CT size). Detached: D's step must not update G.
        fake_ct_resized = self._resize_to_ct(self.fake_ct.detach())

        self.debug_info['xray1_resized_shape'] = xray1_resized.shape
        self.debug_info['xray2_resized_shape'] = xray2_resized.shape
        self.debug_info['ct_volume_shape'] = self.ct_volume.shape
        self.debug_info['fake_ct_resized_shape'] = fake_ct_resized.shape

        # Real branch: [xray1, xray2, real CT] along channels
        real_input = torch.cat([xray1_resized, xray2_resized, self.ct_volume], dim=1)
        pred_real = self.netD(real_input)
        loss_D_real = self.criterionGAN(pred_real, True)

        # Fake branch: same layout with the generated CT, passed through the pool
        fake_input = torch.cat([xray1_resized, xray2_resized, fake_ct_resized], dim=1)
        fake_input = self.fake_pool.query(fake_input)
        pred_fake = self.netD(fake_input)
        loss_D_fake = self.criterionGAN(pred_fake, False)

        self.loss_D = (loss_D_real + loss_D_fake) * 0.5
        self.loss_D.backward()

    def backward_G(self):
        """Compute the generator loss (adversarial + weighted L1) and backprop it.

        The adversarial term scores the same conditioned input layout that
        ``backward_D`` trains the discriminator on.
        """
        fake_ct_5d = self._to_5d(self.fake_ct)
        self.debug_info['fake_ct_5d_shape'] = fake_ct_5d.shape

        xray1_resized = self._resize_to_ct(self.xray1)
        xray2_resized = self._resize_to_ct(self.xray2)
        fake_ct_resized = self._resize_to_ct(fake_ct_5d)

        # Adversarial loss: G wants D to label its conditioned output as real
        fake_input = torch.cat([xray1_resized, xray2_resized, fake_ct_resized], dim=1)
        pred_fake = self.netD(fake_input)
        self.loss_G_GAN = self.criterionGAN(pred_fake, True)

        # Reconstruction loss against the real CT, weighted by idt_lambda
        self.loss_G_L1 = self.criterionL1(self.ct_volume, fake_ct_resized) * self.opt.idt_lambda

        self.loss_G = self.loss_G_GAN + self.loss_G_L1
        self.loss_G.backward()

    def optimize_parameters(self):
        """One training step: generate, update G with D frozen, then update D."""
        self.forward()
        self.set_requires_grad(self.netD, False)
        self.optimizer_G.zero_grad()
        self.backward_G()
        self.optimizer_G.step()

        self.set_requires_grad(self.netD, True)
        self.optimizer_D.zero_grad()
        self.backward_D()
        self.optimizer_D.step()

    def set_requires_grad(self, nets, requires_grad=False):
        """Set ``requires_grad`` on all parameters of one net or a list of nets (``None`` skipped)."""
        if not isinstance(nets, list):
            nets = [nets]
        for net in nets:
            if net is not None:
                for param in net.parameters():
                    param.requires_grad = requires_grad

    def get_debug_info(self):
        """Return the dict of shapes recorded during the backward passes."""
        return self.debug_info



# Build the generator first, then the discriminator (this keeps the same
# random-initialisation order), and wrap both in CTGAN.
generator = Generator3D(
    opt.input_nc_G,
    opt.output_nc_G,
    encoder_input_shape=opt.encoder_input_shape,
    encoder_input_channels=opt.encoder_input_nc,
    decoder_output_channels=opt.output_nc_G,
    decoder_out_activation=nn.Tanh,
)

# Discriminator over the 3-channel (xray1, xray2, CT) input
discriminator = Discriminator3D(opt.input_nc_D, opt.ndf, opt.n_layers_D)

ctgan = CTGAN(opt=opt, generator=generator, discriminator=discriminator)

