<a href="https://colab.research.google.com/github/lhyochan7/MRI-analysis/blob/main/2D_pix2pix.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# Make Google Drive reachable from this runtime under /content/drive.
from google.colab import drive

DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

Mounted at /content/drive


In [2]:
pip install pydicom

Collecting pydicom
  Downloading pydicom-2.2.2-py3-none-any.whl (2.0 MB)
[K     |████████████████████████████████| 2.0 MB 8.1 MB/s 
[?25hInstalling collected packages: pydicom
Successfully installed pydicom-2.2.2


In [3]:
# Standard library
import gc
import glob
import os
import time

# Third-party
import nibabel as nib
import numpy as np
import pydicom
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from matplotlib import pyplot as plt
from numpy import load
from numpy import savez_compressed

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Collect unreachable Python objects, then release PyTorch's cached CUDA memory.
gc.collect()
torch.cuda.empty_cache()

In [4]:
# Per-slice preprocessing pipeline: array -> tensor, normalize with
# mean 0.5 / std 0.5, then resize to 256x256.
_TRANSFORM_STEPS = [
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5]),
    transforms.Resize((256, 256)),
]
transform = transforms.Compose(_TRANSFORM_STEPS)

In [5]:
class DataLoader():
  """Builds paired (origin, target) tensors from two sets of DICOM files.

  Args:
    o_path: glob pattern selecting the origin DICOM files.
    t_path: glob pattern selecting the target DICOM files.

  Attributes set by ``preprocess``:
    org_list / tar_list: one (1, 1, 256, 256) tensor per matched pair.
    o_dataset / t_dataset: single-element lists holding the stacked
      (N, 1, 256, 256) tensors.
  """

  def __init__(self, o_path, t_path):
    self.o_dataset = []
    self.t_dataset = []
    self.org_list = []
    self.tar_list = []
    self.o_path = o_path
    self.t_path = t_path

  @staticmethod
  def _load_slice(path):
    """Read one DICOM file and return it as a (1, 1, 256, 256) tensor."""
    pixels = pydicom.dcmread(path).pixel_array.astype('float32')
    # NOTE(review): the array is float32 when it reaches `transform`;
    # torchvision's ToTensor only rescales uint8 input, so the values are
    # presumably not mapped into [-1, 1] here -- confirm this matches the
    # generator's Tanh output range.
    return transform(pixels).reshape(1, 1, 256, 256)

  def preprocess(self):
    """Load, pair and stack the origin/target slices.

    An origin file is paired with every target file whose basename shares
    the same first 9 characters; pairs are emitted in glob order (origin
    outer, target inner).

    Returns:
      (o_dataset, t_dataset): two single-element lists, each holding one
      tensor of shape (N, 1, 256, 256) where N is the number of pairs.

    Raises:
      ValueError: if no origin/target pair was found.
    """
    # Start from a clean state so calling preprocess() again does not
    # accumulate the slices of the previous call.
    self.o_dataset, self.t_dataset = [], []
    self.org_list, self.tar_list = [], []

    # Each file is decoded and transformed at most once, even when it takes
    # part in several pairs.
    slice_cache = {}

    def load_cached(path):
      if path not in slice_cache:
        slice_cache[path] = self._load_slice(path)
      return slice_cache[path]

    # The target listing does not depend on the origin file: glob it once.
    tar_files = glob.glob(self.t_path)

    for org_file in glob.glob(self.o_path):
      org_key = os.path.basename(org_file)[0:9]
      for tar_file in tar_files:
        if os.path.basename(tar_file)[0:9] != org_key:
          continue
        self.org_list.append(load_cached(org_file))
        self.tar_list.append(load_cached(tar_file))

    if not self.org_list:
      raise ValueError(
          'no matching DICOM pairs found for %r and %r' % (self.o_path, self.t_path))

    # Stack along the batch dimension in a single concatenation.
    self.o_dataset.append(torch.cat(self.org_list, 0))
    self.t_dataset.append(torch.cat(self.tar_list, 0))

    # return original and target dataset
    print(self.o_dataset[0].shape)
    return self.o_dataset, self.t_dataset

In [6]:
# Glob patterns of the two DICOM series that get paired slice by slice.
ORIG_GLOB = '/content/drive/MyDrive/GAN_Data/ADNI_002_S_1070_2006_result/*.dcm'
TARGET_GLOB = '/content/drive/MyDrive/GAN_Data/ADNI_002_S_1070_2009_result/*.dcm'

data = DataLoader(ORIG_GLOB, TARGET_GLOB)
orig, target = data.preprocess()

torch.Size([160, 1, 256, 256])


In [7]:
# U-Net encoder stage
class UNetDown(nn.Module):
    """Downsampling stage: 4x4 stride-2 convolution (no bias), optional
    instance normalization, LeakyReLU(0.2) and optional dropout.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels produced by the convolution.
        normalize: insert ``InstanceNorm2d`` after the convolution.
        dropout: dropout probability; a falsy value adds no dropout layer.
    """

    def __init__(self, in_channels, out_channels, normalize=True, dropout=0.0):
        super().__init__()

        conv = nn.Conv2d(in_channels, out_channels, 4, stride=2, padding=1, bias=False)
        norm = [nn.InstanceNorm2d(out_channels)] if normalize else []
        drop = [nn.Dropout(dropout)] if dropout else []

        self.down = nn.Sequential(conv, *norm, nn.LeakyReLU(0.2), *drop)

    def forward(self, x):
        """Return ``x`` after the stage; height and width are halved."""
        return self.down(x)

In [8]:
class UNetUp(nn.Module):
    """Upsampling stage: 4x4 stride-2 transposed convolution (no bias),
    instance normalization, LeakyReLU and optional dropout, followed by a
    channel-wise concatenation with a skip tensor.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels produced by the transposed convolution.
        dropout: dropout probability; a falsy value adds no dropout layer.
    """

    def __init__(self, in_channels, out_channels, dropout=0.0):
        super().__init__()

        stages = [
            nn.ConvTranspose2d(in_channels, out_channels, 4, 2, 1, bias=False),
            nn.InstanceNorm2d(out_channels),
            nn.LeakyReLU(),
        ]
        stages += [nn.Dropout(dropout)] if dropout else []

        self.up = nn.Sequential(*stages)

    def forward(self, x, skip):
        """Upsample ``x`` and append ``skip`` along the channel dimension."""
        upsampled = self.up(x)
        return torch.cat((upsampled, skip), 1)

In [9]:
# Generator: produces the fake (translated) image.
class GeneratorUNet(nn.Module):
    """U-Net generator with eight downsampling and eight upsampling stages.

    Every decoder stage except the last receives the output of the mirrored
    encoder stage as a skip connection. The final stage is a transposed
    convolution followed by Tanh.

    Args:
        in_channels: channels of the input image.
        out_channels: channels of the generated image.
    """

    def __init__(self, in_channels=1, out_channels=1):
        super().__init__()

        # Encoder: each stage halves height and width.
        self.down1 = UNetDown(in_channels, 64, normalize=False)
        self.down2 = UNetDown(64,128)
        self.down3 = UNetDown(128,256)
        self.down4 = UNetDown(256,512,dropout=0.5)
        self.down5 = UNetDown(512,512,dropout=0.5)
        self.down6 = UNetDown(512,512,dropout=0.5)
        self.down7 = UNetDown(512,512,dropout=0.5)
        self.down8 = UNetDown(512,512,normalize=False,dropout=0.5)

        # Decoder: input channel counts double after up1 because each stage
        # consumes the previous output concatenated with its skip tensor.
        self.up1 = UNetUp(512,512,dropout=0.5)
        self.up2 = UNetUp(1024,512,dropout=0.5)
        self.up3 = UNetUp(1024,512,dropout=0.5)
        self.up4 = UNetUp(1024,512,dropout=0.5)
        self.up5 = UNetUp(1024,256)
        self.up6 = UNetUp(512,128)
        self.up7 = UNetUp(256,64)
        self.up8 = nn.Sequential(
            # Honour out_channels instead of a hard-coded single channel.
            nn.ConvTranspose2d(128,out_channels,4,stride=2,padding=1),
            nn.Tanh()
        )

    def forward(self, x):
        """Return the generated image for input batch ``x``.

        The eight stride-2 encoder stages reduce a 256x256 input to 1x1 at
        the bottleneck; the decoder restores the input resolution.
        """
        d1 = self.down1(x)
        d2 = self.down2(d1)
        d3 = self.down3(d2)
        d4 = self.down4(d3)
        d5 = self.down5(d4)
        d6 = self.down6(d5)
        d7 = self.down7(d6)
        d8 = self.down8(d7)

        u1 = self.up1(d8,d7)
        u2 = self.up2(u1,d6)
        u3 = self.up3(u2,d5)
        u4 = self.up4(u3,d4)
        u5 = self.up5(u4,d3)
        u6 = self.up6(u5,d2)
        u7 = self.up7(u6,d1)
        u8 = self.up8(u7)

        return u8

# Disabled shape check (uncomment to run):
# x = torch.randn(160,1,256,256,device=device)
# model = GeneratorUNet().to(device)
# out = model(x)
# print(out.shape)

'\n# check\nx = torch.randn(160,1,256,256,device=device)\nmodel = GeneratorUNet().to(device)\nout = model(x)\nprint(out.shape)'

In [10]:
class Dis_block(nn.Module):
    """Discriminator stage: 3x3 stride-2 convolution, optional instance
    normalization, then LeakyReLU(0.2).

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels produced by the convolution.
        normalize: insert ``InstanceNorm2d`` after the convolution.
    """

    def __init__(self, in_channels, out_channels, normalize=True):
        super().__init__()

        conv = nn.Conv2d(in_channels, out_channels, 3, stride=2, padding=1)
        middle = (nn.InstanceNorm2d(out_channels),) if normalize else ()

        self.block = nn.Sequential(conv, *middle, nn.LeakyReLU(0.2))

    def forward(self, x):
        """Return ``x`` after the stage; height and width are halved."""
        return self.block(x)

In [11]:
# The discriminator is a PatchGAN.
# PatchGAN: instead of one score per image it outputs a grid of scores, each
# judging whether one patch of the image is real or fake.
# This improves accuracy on high-frequency detail.

class Discriminator(nn.Module):
    """Conditional PatchGAN discriminator.

    The two input images are concatenated along the channel dimension and
    passed through four stride-2 stages, which shrink height and width by a
    factor of 16 (256x256 -> 16x16). A final convolution maps the features
    to one channel and a sigmoid turns it into per-patch probabilities.

    Args:
        in_channels: channels of each of the two input images.
    """

    def __init__(self, in_channels=1):
        super().__init__()

        self.stage_1 = Dis_block(in_channels*2,64,normalize=False)
        self.stage_2 = Dis_block(64,128)
        self.stage_3 = Dis_block(128,256)
        self.stage_4 = Dis_block(256,512)

        # 3x3 kernel with padding=1 keeps the spatial size, giving the 16x16
        # patch map. The previous 1x1 kernel with padding=1 grew the map to
        # 18x18, which no longer matched the (1, 16, 16) patch labels.
        self.patch = nn.Conv2d(512,1,3,padding=1)

    def forward(self,a,b):
        """Return per-patch probabilities for the image pair ``(a, b)``."""
        x = torch.cat((a,b),1)
        x = self.stage_1(x)
        x = self.stage_2(x)
        x = self.stage_3(x)
        x = self.stage_4(x)
        x = self.patch(x)
        x = torch.sigmoid(x)
        return x

In [12]:
# Single-channel generator and discriminator, placed on the selected device.
model_gen = GeneratorUNet(in_channels=1, out_channels=1).to(device)
model_dis = Discriminator(in_channels=1).to(device)

In [13]:
# Weight initialization
def initialize_weights(model):
    """Re-draw the weights of convolution-like modules from N(0, 0.02).

    Intended to be passed to ``nn.Module.apply``. A module is touched only
    when its class name contains 'Conv'; every other module is left as is.
    """
    if 'Conv' not in type(model).__name__:
        return
    nn.init.normal_(model.weight.data, 0.0, 0.02)


# Apply the weight initialization to both networks
for _network in (model_gen, model_dis):
    _network.apply(initialize_weights)

In [14]:
from torch import optim

# Loss functions: binary cross-entropy for the adversarial term, L1 for the
# pixel-wise term.
loss_func_gan = nn.BCELoss()
loss_func_pix = nn.L1Loss()

# Weight of the pixel-wise (L1) term in the generator loss
lambda_pixel = 100

# Shape of the patch label map for one sample
_patch_side = 256 // 2 ** 4
patch = (1, _patch_side, _patch_side)

# Optimizer hyper-parameters
lr = 2e-4
beta1 = 0.5
beta2 = 0.999
_adam_kwargs = dict(lr=lr, betas=(beta1, beta2))

opt_dis = optim.Adam(model_dis.parameters(), **_adam_kwargs)
opt_gen = optim.Adam(model_gen.parameters(), **_adam_kwargs)

In [None]:
# Training
model_gen.train()
model_dis.train()

batch_count = 0
num_epochs = 100
log_every = 1  # print the losses every `log_every` optimisation steps
start_time = time.time()

# Per-step loss values of the generator and the discriminator
loss_hist = {'gen':[],
             'dis':[]}

for epoch in range(num_epochs):
    # NOTE(review): `orig` and `target` each hold a single stacked tensor
    # (see the shape printed by preprocess), so every step below runs on the
    # whole dataset at once -- consider mini-batches if memory is a problem.
    for a, b in zip(orig, target):
        # real image pair: a is the input, b is the target
        real_a = a.to(device)
        real_b = b.to(device)

        # ---- generator ----
        model_gen.zero_grad()

        fake_b = model_gen(real_a) # generate the fake image
        # Condition on the input image, exactly as the discriminator is
        # trained below; conditioning on real_b here would hand the
        # discriminator the target the generator is trying to reproduce.
        out_dis = model_dis(fake_b, real_a) # score the fake image

        # Patch labels are built from the discriminator output so that their
        # shape and device always match what BCELoss compares them with.
        real_label = torch.ones_like(out_dis)
        fake_label = torch.zeros_like(out_dis)

        gen_loss = loss_func_gan(out_dis, real_label)
        pixel_loss = loss_func_pix(fake_b, real_b)

        g_loss = gen_loss + lambda_pixel * pixel_loss
        g_loss.backward()
        opt_gen.step()

        # ---- discriminator ----
        # Also clears the gradients the generator step left on the
        # discriminator parameters.
        model_dis.zero_grad()

        out_dis = model_dis(real_b, real_a) # score the real image
        real_loss = loss_func_gan(out_dis,real_label)

        # detach(): the discriminator update must not back-propagate into
        # the generator.
        out_dis = model_dis(fake_b.detach(), real_a) # score the fake image
        fake_loss = loss_func_gan(out_dis,fake_label)

        d_loss = (real_loss + fake_loss) / 2.
        d_loss.backward()
        opt_dis.step()

        loss_hist['gen'].append(g_loss.item())
        loss_hist['dis'].append(d_loss.item())

        batch_count += 1
        if batch_count % log_every == 0:
            print('Epoch: %.0f, G_Loss: %.6f, D_Loss: %.6f, time: %.2f min' %(epoch, g_loss.item(), d_loss.item(), (time.time()-start_time)/60))

torch.Size([160, 64, 128, 128])
torch.Size([160, 128, 64, 64])
torch.Size([160, 256, 32, 32])
torch.Size([160, 512, 16, 16])
torch.Size([160, 512, 8, 8])
torch.Size([160, 512, 4, 4])
torch.Size([160, 512, 2, 2])
torch.Size([160, 512, 1, 1])
torch.Size([160, 512, 2, 2])
torch.Size([160, 512, 4, 4])
torch.Size([160, 512, 8, 8])
torch.Size([160, 512, 16, 16])
torch.Size([160, 256, 32, 32])
torch.Size([160, 128, 64, 64])
torch.Size([160, 64, 128, 128])
