**MRI-to-CT conversion using paired and unpaired datasets**

For this project we use GANs (Generative Adversarial Networks).<br/>
To implement the model we use PyTorch as the neural-network library.<br/>These are all the libraries/dependencies we need to import:


In [None]:
!pip install albumentations==0.4.6
import albumentations as A
from albumentations.pytorch import ToTensor
from PIL import Image
import os, random, copy, os, sys, torch
from torch.utils.data import Dataset
import numpy as np
from albumentations.pytorch import ToTensor
import torch.nn as nn
from torch.utils.data import DataLoader
import torch.optim as optim
from tqdm import tqdm
from torchvision.utils import save_image
import matplotlib.pyplot as plt
from google.colab import drive
drive.mount('/content/drive')

Collecting albumentations==0.4.6
  Downloading albumentations-0.4.6.tar.gz (117 kB)
[K     |████████████████████████████████| 117 kB 5.1 MB/s 
Collecting imgaug>=0.4.0
  Downloading imgaug-0.4.0-py2.py3-none-any.whl (948 kB)
[K     |████████████████████████████████| 948 kB 43.2 MB/s 
Building wheels for collected packages: albumentations
  Building wheel for albumentations (setup.py) ... [?25l[?25hdone
  Created wheel for albumentations: filename=albumentations-0.4.6-py3-none-any.whl size=65174 sha256=afba2568ea71f8bbb36e73e77ec23730aadaf6af3eb44260142dafb0900f9b21
  Stored in directory: /root/.cache/pip/wheels/cf/34/0f/cb2a5f93561a181a4bcc84847ad6aaceea8b5a3127469616cc
Successfully built albumentations
Installing collected packages: imgaug, albumentations
  Attempting uninstall: imgaug
    Found existing installation: imgaug 0.2.9
    Uninstalling imgaug-0.2.9:
      Successfully uninstalled imgaug-0.2.9
  Attempting uninstall: albumentations
    Found existing installation: album

**Discriminator Block**

In [None]:
class Block(nn.Module):
    """One discriminator stage: 4x4 reflect-padded conv -> InstanceNorm -> LeakyReLU(0.2).

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels produced by the convolution.
        stride: convolution stride. With kernel 4 / padding 1, stride 2 halves
            the spatial size and stride 1 shrinks it by one pixel.
    """

    def __init__(self, in_channels, out_channels, stride):
        super().__init__()
        conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=4,
            stride=stride,
            padding=1,
            bias=True,
            padding_mode="reflect",
        )
        norm = nn.InstanceNorm2d(out_channels)
        activation = nn.LeakyReLU(0.2, inplace=True)
        # A single Sequential named `conv` keeps the state_dict keys
        # (conv.0.weight, conv.0.bias) compatible with saved checkpoints.
        self.conv = nn.Sequential(conv, norm, activation)

    def forward(self, x):
        return self.conv(x)


class Discriminator(nn.Module):
    """Convolutional critic that maps an image to a spatial map of scores in (0, 1).

    The output has a single channel and is passed through a sigmoid, so each
    location scores one region of the input as real (-> 1) or fake (-> 0).

    Args:
        in_channels: number of channels of the input image.
        features: channel widths of the successive stages. The first entry is
            the un-normalized stride-2 stem; every following entry becomes a
            ``Block`` with stride 2, except the last one, which uses stride 1.
            Any sequence is accepted; defaults to ``(64, 128, 256, 512)``.
    """

    def __init__(self, in_channels=3, features=(64, 128, 256, 512)):
        super().__init__()
        # Copy into a fresh list: the previous list default was a shared,
        # mutable object, and callers may pass tuples.
        features = list(features)
        self.initial = nn.Sequential(
            nn.Conv2d(
                in_channels,
                features[0],
                kernel_size=4,
                stride=2,
                padding=1,
                padding_mode="reflect",
            ),
            nn.LeakyReLU(0.2, inplace=True),
        )

        layers = []
        prev_channels = features[0]
        last_index = len(features) - 1
        for index, width in enumerate(features[1:], start=1):
            # Only the final stage keeps stride 1. The choice is made by position,
            # not by value, so repeated widths such as [64, 128, 128] do not give
            # an earlier stage stride 1 as well.
            stride = 1 if index == last_index else 2
            layers.append(Block(prev_channels, width, stride=stride))
            prev_channels = width
        # Project to a single score channel; sigmoid is applied in forward().
        layers.append(
            nn.Conv2d(prev_channels, 1, kernel_size=4, stride=1, padding=1, padding_mode="reflect")
        )
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        x = self.initial(x)
        return torch.sigmoid(self.model(x))

**Generator Block**

In [None]:
class ConvBlock(nn.Module):
    """Generator building block: (transposed) conv -> InstanceNorm -> ReLU or Identity.

    Args:
        in_channels: channels of the input feature map.
        out_channels: channels produced by the convolution.
        down: True selects a reflect-padded ``nn.Conv2d``; False selects an
            ``nn.ConvTranspose2d`` (used for upsampling in the generator).
        use_act: True appends ``ReLU``; False appends ``Identity``.
        **kwargs: forwarded unchanged to the convolution layer
            (kernel_size, stride, padding, output_padding, ...).
    """

    def __init__(self, in_channels, out_channels, down=True, use_act=True, **kwargs):
        super().__init__()
        if down:
            conv_layer = nn.Conv2d(in_channels, out_channels, padding_mode="reflect", **kwargs)
        else:
            conv_layer = nn.ConvTranspose2d(in_channels, out_channels, **kwargs)
        norm = nn.InstanceNorm2d(out_channels)
        if use_act:
            activation = nn.ReLU(inplace=True)
        else:
            activation = nn.Identity()
        # A single Sequential named `conv` keeps state_dict keys checkpoint-compatible.
        self.conv = nn.Sequential(conv_layer, norm, activation)

    def forward(self, x):
        return self.conv(x)

class ResidualBlock(nn.Module):
    """Residual unit computing ``x + F(x)``.

    F is two 3x3, padding-1 ``ConvBlock``s; only the first one has a ReLU.
    Channel count and spatial size are therefore preserved.

    Args:
        channels: number of input (and output) channels.
    """

    def __init__(self, channels):
        super().__init__()
        first = ConvBlock(channels, channels, kernel_size=3, padding=1)
        second = ConvBlock(channels, channels, use_act=False, kernel_size=3, padding=1)
        self.block = nn.Sequential(first, second)

    def forward(self, x):
        residual = self.block(x)
        return x + residual

class Generator(nn.Module):
    """Encoder / residual / decoder image-to-image generator with tanh output.

    Layout, with C = ``num_features``:
    7x7 stem conv + ReLU (C) -> two stride-2 ConvBlocks (C -> 2C -> 4C)
    -> ``num_residuals`` ResidualBlocks at 4C -> two stride-2 transposed
    ConvBlocks (4C -> 2C -> C) -> 7x7 conv back to ``img_channels`` -> tanh.

    Args:
        img_channels: channels of both the input and the output image.
        num_features: base channel width C.
        num_residuals: number of residual blocks in the bottleneck.
    """

    def __init__(self, img_channels, num_features=64, num_residuals=9):
        super().__init__()
        c = num_features
        # No normalization in the stem (an InstanceNorm2d here is commented out in the original).
        self.initial = nn.Sequential(
            nn.Conv2d(img_channels, c, kernel_size=7, stride=1, padding=3, padding_mode="reflect"),
            nn.ReLU(inplace=True),
        )
        self.down_blocks = nn.ModuleList([
            ConvBlock(c, c * 2, kernel_size=3, stride=2, padding=1),
            ConvBlock(c * 2, c * 4, kernel_size=3, stride=2, padding=1),
        ])
        self.res_blocks = nn.Sequential(*(ResidualBlock(c * 4) for _ in range(num_residuals)))
        self.up_blocks = nn.ModuleList([
            ConvBlock(c * 4, c * 2, down=False, kernel_size=3, stride=2, padding=1, output_padding=1),
            ConvBlock(c * 2, c, down=False, kernel_size=3, stride=2, padding=1, output_padding=1),
        ])
        self.last = nn.Conv2d(c, img_channels, kernel_size=7, stride=1, padding=3, padding_mode="reflect")

    def forward(self, x):
        x = self.initial(x)
        # Downsample, run the residual bottleneck, then upsample back.
        for stage in (*self.down_blocks, self.res_blocks, *self.up_blocks):
            x = stage(x)
        return torch.tanh(self.last(x))

**Dataset class (image loading and pre-processing)**

In [None]:
class MriCtDataset(Dataset):
    """Dataset yielding ``(ct_image, mri_image)`` tuples read from two directories.

    Both directory listings are sorted by file name, so the i-th CT file and
    the i-th MRI file (in sorted order) are returned together. The dataset
    length is the number of MRI files; indices are taken modulo each list's
    length, so the CT list wraps around if it is shorter.

    Args:
        root_ct: directory containing CT images.
        root_mri: directory containing MRI images.
        transform: optional callable invoked as ``transform(image=ct, image0=mri)``
            and expected to return a dict with keys "image" and "image0". The
            notebook's albumentations pipeline registers "image0" as an extra
            image target so both receive the same random augmentation.
    """

    def __init__(self, root_ct, root_mri, transform=None):
        self.root_ct = root_ct
        self.root_mri = root_mri
        self.transform = transform

        self.ct_images = sorted(os.listdir(root_ct))
        self.mri_images = sorted(os.listdir(root_mri))
        self.ct_len = len(self.ct_images)
        self.mri_len = len(self.mri_images)
        # Attribute name kept as-is in case other code reads it.
        self.lengtM_dataset = self.mri_len

    def __len__(self):
        return self.lengtM_dataset

    def __getitem__(self, index):
        ct_name = self.ct_images[index % self.ct_len]
        mri_name = self.mri_images[index % self.mri_len]

        # Force 3-channel RGB so grayscale scans match the 3-channel networks.
        ct_array = np.array(Image.open(os.path.join(self.root_ct, ct_name)).convert("RGB"))
        mri_array = np.array(Image.open(os.path.join(self.root_mri, mri_name)).convert("RGB"))

        if not self.transform:
            return ct_array, mri_array

        augmented = self.transform(image=ct_array, image0=mri_array)
        return augmented["image"], augmented["image0"]

**All constant values**

In [None]:
# Use the GPU whenever PyTorch can see one.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Google Drive roots. PUP_DIR holds the "PUP" run's checkpoints (saved_checkp/)
# and sample images (saved_image/). CYCLE_DIR holds the CycleGAN checkpoints
# and, in main(), the MRI/CT image folders.
PUP_DIR = "/content/drive/MyDrive/data/"
CYCLE_DIR = "/content/drive/MyDrive/data_c/"

BATCH_SIZE = 1
LEARNING_RATE = 1e-4
LAMBDA_IDENTITY = 0.0  # identity-loss weight (0: the term contributes nothing)
LAMBDA_CYCLE = 10      # cycle-consistency loss weight
LAMBDA_PAIRED = 30     # paired L1 weight, used only for paired training
NUM_WORKERS = 4
NUM_EPOCHS = 1
LOAD_MODEL = True      # main(): resume from the *_PUP checkpoints
SAVE_MODEL = False     # main(): overwrite the *_PUP checkpoints after each epoch

# Checkpoint files. *_M_* belong to the MRI generator/critic, *_C_* to the CT side.
CHECKPOINT_GEN_M_PUP = f"{PUP_DIR}saved_checkp/genh.pth.tar"
CHECKPOINT_GEN_C_PUP = f"{PUP_DIR}saved_checkp/genz.pth.tar"
CHECKPOINT_CRITIC_M_PUP = f"{PUP_DIR}saved_checkp/critich.pth.tar"
CHECKPOINT_CRITIC_C_PUP = f"{PUP_DIR}saved_checkp/criticz.pth.tar"
CHECKPOINT_GEN_M_CYCLE = f"{CYCLE_DIR}saved_checkp/genh.pth.tar"
CHECKPOINT_GEN_C_CYCLE = f"{CYCLE_DIR}saved_checkp/genz.pth.tar"
CHECKPOINT_CRITIC_M_CYCLE = f"{CYCLE_DIR}saved_checkp/critich.pth.tar"
CHECKPOINT_CRITIC_C_CYCLE = f"{CYCLE_DIR}saved_checkp/criticz.pth.tar"

# Shared augmentation pipeline. additional_targets makes "image0" (the MRI)
# receive exactly the same random resize/flip as "image" (the CT). Normalize
# with mean = std = 0.5 and max_pixel_value 255 maps 0..255 pixels to [-1, 1],
# the range of the generators' tanh output.
transforms = A.Compose(
    [
        A.Resize(height=256, width=256),
        A.HorizontalFlip(p=0.5),
        A.Normalize(mean=[0.5] * 3, std=[0.5] * 3, max_pixel_value=255),
        ToTensor(),
    ],
    additional_targets={"image0": "image"},
)

**Saving and loading checkpoints**

In [None]:
def save_checkpoint(model, optimizer, filename="my_checkpoint.pth.tar"):
    """Write the model and optimizer state dicts to ``filename`` with ``torch.save``.

    The file contains a dict with keys "state_dict" (model) and "optimizer",
    which is the layout ``load_checkpoint`` reads back.
    """
    print("=> Saving checkpoint")
    torch.save(
        {"state_dict": model.state_dict(), "optimizer": optimizer.state_dict()},
        filename,
    )


def load_checkpoint(checkpoint_file, model, optimizer, lr):
    """Restore ``model`` and ``optimizer`` from a file written by ``save_checkpoint``.

    Tensors are mapped onto ``DEVICE`` while loading. Afterwards every
    optimizer param group's learning rate is overwritten with ``lr``.
    """
    print("=> Loading checkpoint")
    state = torch.load(checkpoint_file, map_location=DEVICE)
    model.load_state_dict(state["state_dict"])
    optimizer.load_state_dict(state["optimizer"])

    # The saved optimizer state carries the learning rate it was trained with.
    # Override it so the current setting takes effect after resuming.
    for group in optimizer.param_groups:
        group["lr"] = lr


def seed_everything(seed=42):
    """Seed Python, NumPy and PyTorch (CPU + all CUDA devices) and make cuDNN deterministic.

    Note: setting PYTHONHASHSEED here only affects Python processes started
    afterwards; the running interpreter's hash seed is fixed at startup.
    """
    os.environ["PYTHONHASHSEED"] = str(seed)
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)
    # Trade cuDNN autotuning speed for reproducible kernel selection.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

**Training function for PUP GANs**

In [None]:

def train_fn(disc_M, disc_C, gen_C, gen_M, loader, opt_disc, opt_gen, l1, mse, d_scaler, g_scaler, nature):
    """Run one epoch of joint MRI<->CT GAN training, optionally with a paired L1 term.

    For every batch this performs one critic update, then one generator update.
    Both run under CUDA autocast with their own GradScaler.

    Args:
        disc_M, disc_C: critics for the MRI and CT domains.
        gen_C: generator mapping MRI -> CT.
        gen_M: generator mapping CT -> MRI.
        loader: iterable yielding ``(ct, mri)`` batches.
        opt_disc, opt_gen: optimizers for the critics and for the generators.
        l1: reconstruction loss (cycle / identity / paired terms).
        mse: adversarial loss against 1/0 targets.
        d_scaler, g_scaler: GradScalers for the critic and generator steps.
        nature: True for paired data. This adds
            ``LAMBDA_PAIRED * (l1(ct, gen_C(mri)) + l1(mri, gen_M(ct)))`` and
            saves sample images every iteration. False for unpaired data, which
            saves samples every 400 iterations.
    """
    M_reals = 0
    M_fakes = 0
    loop = tqdm(loader, leave=True)

    for idx, (ct, mri) in enumerate(loop):
        ct = ct.to(DEVICE)
        mri = mri.to(DEVICE)

        # ---- Critic update: real -> 1, fake -> 0 ----
        with torch.cuda.amp.autocast():
            fake_mri = gen_M(ct)
            D_M_real = disc_M(mri)
            # detach(): the critic step must not send gradients into the generators.
            D_M_fake = disc_M(fake_mri.detach())
            M_reals += D_M_real.mean().item()
            M_fakes += D_M_fake.mean().item()
            D_M_real_loss = mse(D_M_real, torch.ones_like(D_M_real))
            D_M_fake_loss = mse(D_M_fake, torch.zeros_like(D_M_fake))
            D_M_loss = D_M_real_loss + D_M_fake_loss

            fake_ct = gen_C(mri)
            D_C_real = disc_C(ct)
            D_C_fake = disc_C(fake_ct.detach())
            D_C_real_loss = mse(D_C_real, torch.ones_like(D_C_real))
            D_C_fake_loss = mse(D_C_fake, torch.zeros_like(D_C_fake))
            D_C_loss = D_C_real_loss + D_C_fake_loss

            D_loss = (D_M_loss + D_C_loss) / 2

        opt_disc.zero_grad()
        d_scaler.scale(D_loss).backward()
        d_scaler.step(opt_disc)
        d_scaler.update()

        # ---- Generator update ----
        with torch.cuda.amp.autocast():
            # Adversarial terms: the generators want the critics to output 1 on fakes.
            D_M_fake = disc_M(fake_mri)
            D_C_fake = disc_C(fake_ct)
            loss_G_M = mse(D_M_fake, torch.ones_like(D_M_fake))
            loss_G_C = mse(D_C_fake, torch.ones_like(D_C_fake))

            # Cycle consistency: CT -> MRI -> CT and MRI -> CT -> MRI.
            cycle_ct_loss = l1(ct, gen_C(fake_mri))
            cycle_mri_loss = l1(mri, gen_M(fake_ct))

            G_loss = (
                loss_G_C
                + loss_G_M
                + cycle_ct_loss * LAMBDA_CYCLE
                + cycle_mri_loss * LAMBDA_CYCLE
            )

            # The identity terms cost two extra generator passes plus their backward.
            # Skip them entirely when their weight is zero, since they would add nothing.
            if LAMBDA_IDENTITY != 0:
                identity_ct_loss = l1(ct, gen_C(ct))
                identity_mri_loss = l1(mri, gen_M(mri))
                G_loss = (
                    G_loss
                    + identity_mri_loss * LAMBDA_IDENTITY
                    + identity_ct_loss * LAMBDA_IDENTITY
                )

            # Paired (supervised) term: each translation should match the other image of its pair.
            if nature:
                P_loss = l1(ct, fake_ct) + l1(mri, fake_mri)
                G_loss = G_loss + P_loss * LAMBDA_PAIRED

        opt_gen.zero_grad()
        g_scaler.scale(G_loss).backward()
        g_scaler.step(opt_gen)
        g_scaler.update()

        # Images are in [-1, 1]; *0.5 + 0.5 maps them back to [0, 1] for saving.
        if nature:
            # NOTE(review): paired mode writes four images on *every* iteration
            # (unpaired: every 400). Confirm this is intended; it dominates Drive I/O.
            save_image(mri * 0.5 + 0.5, PUP_DIR + f"saved_image/MR_PR{idx}.png")
            save_image(ct * 0.5 + 0.5, PUP_DIR + f"saved_image/CT_PR{idx}.png")
            save_image(fake_mri * 0.5 + 0.5, PUP_DIR + f"saved_image/MR_P{idx}.png")
            save_image(fake_ct * 0.5 + 0.5, PUP_DIR + f"saved_image/CT_P{idx}.png")
        elif idx % 400 == 0:
            save_image(mri * 0.5 + 0.5, PUP_DIR + f"saved_image/MR_UPR{idx}.png")
            save_image(ct * 0.5 + 0.5, PUP_DIR + f"saved_image/CT_UPR{idx}.png")
            save_image(fake_mri * 0.5 + 0.5, PUP_DIR + f"saved_image/MR_UP{idx}.png")
            save_image(fake_ct * 0.5 + 0.5, PUP_DIR + f"saved_image/CT_UP{idx}.png")

        # Running mean of the MRI critic's scores on real and fake MRI.
        loop.set_postfix(M_real=M_reals / (idx + 1), M_fake=M_fakes / (idx + 1))

**Training function for CycleGANs**

In [None]:
def train_fn_cycle(disc_M, disc_C, gen_C, gen_M, loader, opt_disc, opt_gen, l1, mse, d_scaler, g_scaler):
    """Run one epoch of unpaired (CycleGAN-style) MRI<->CT training.

    For every batch this performs one critic update, then one generator update
    (adversarial + cycle + optional identity terms). Both run under CUDA
    autocast with their own GradScaler. Sample images are saved every 200
    iterations.

    Args:
        disc_M, disc_C: critics for the MRI and CT domains.
        gen_C: generator mapping MRI -> CT.
        gen_M: generator mapping CT -> MRI.
        loader: iterable yielding ``(ct, mri)`` batches.
        opt_disc, opt_gen: optimizers for the critics and for the generators.
        l1: reconstruction loss (cycle / identity terms).
        mse: adversarial loss against 1/0 targets.
        d_scaler, g_scaler: GradScalers for the critic and generator steps.
    """
    M_reals = 0
    M_fakes = 0
    loop = tqdm(loader, leave=True)

    for idx, (ct, mri) in enumerate(loop):
        ct = ct.to(DEVICE)
        mri = mri.to(DEVICE)

        # ---- Critic update: real -> 1, fake -> 0 ----
        with torch.cuda.amp.autocast():
            fake_mri = gen_M(ct)
            D_M_real = disc_M(mri)
            # detach(): the critic step must not send gradients into the generators.
            D_M_fake = disc_M(fake_mri.detach())
            M_reals += D_M_real.mean().item()
            M_fakes += D_M_fake.mean().item()
            D_M_real_loss = mse(D_M_real, torch.ones_like(D_M_real))
            D_M_fake_loss = mse(D_M_fake, torch.zeros_like(D_M_fake))
            D_M_loss = D_M_real_loss + D_M_fake_loss

            fake_ct = gen_C(mri)
            D_C_real = disc_C(ct)
            D_C_fake = disc_C(fake_ct.detach())
            D_C_real_loss = mse(D_C_real, torch.ones_like(D_C_real))
            D_C_fake_loss = mse(D_C_fake, torch.zeros_like(D_C_fake))
            D_C_loss = D_C_real_loss + D_C_fake_loss

            D_loss = (D_M_loss + D_C_loss) / 2

        opt_disc.zero_grad()
        d_scaler.scale(D_loss).backward()
        d_scaler.step(opt_disc)
        d_scaler.update()

        # ---- Generator update ----
        with torch.cuda.amp.autocast():
            # Adversarial terms: the generators want the critics to output 1 on fakes.
            D_M_fake = disc_M(fake_mri)
            D_C_fake = disc_C(fake_ct)
            loss_G_M = mse(D_M_fake, torch.ones_like(D_M_fake))
            loss_G_C = mse(D_C_fake, torch.ones_like(D_C_fake))

            # Cycle consistency: CT -> MRI -> CT and MRI -> CT -> MRI.
            cycle_ct_loss = l1(ct, gen_C(fake_mri))
            cycle_mri_loss = l1(mri, gen_M(fake_ct))

            G_loss = (
                loss_G_C
                + loss_G_M
                + cycle_ct_loss * LAMBDA_CYCLE
                + cycle_mri_loss * LAMBDA_CYCLE
            )

            # The identity terms cost two extra generator passes plus their backward.
            # Skip them entirely when their weight is zero, since they would add nothing.
            if LAMBDA_IDENTITY != 0:
                identity_ct_loss = l1(ct, gen_C(ct))
                identity_mri_loss = l1(mri, gen_M(mri))
                G_loss = (
                    G_loss
                    + identity_mri_loss * LAMBDA_IDENTITY
                    + identity_ct_loss * LAMBDA_IDENTITY
                )

        opt_gen.zero_grad()
        g_scaler.scale(G_loss).backward()
        g_scaler.step(opt_gen)
        g_scaler.update()

        if idx % 200 == 0:
            # NOTE(review): this writes into PUP_DIR (not CYCLE_DIR), using the same
            # file names as train_fn's unpaired branch, so the two runs overwrite
            # each other's samples. Confirm this is intended.
            # Images are in [-1, 1]; *0.5 + 0.5 maps them back to [0, 1] for saving.
            save_image(mri * 0.5 + 0.5, PUP_DIR + f"saved_image/MR_UPR{idx}.png")
            save_image(ct * 0.5 + 0.5, PUP_DIR + f"saved_image/CT_UPR{idx}.png")
            save_image(fake_mri * 0.5 + 0.5, PUP_DIR + f"saved_image/MR_UP{idx}.png")
            save_image(fake_ct * 0.5 + 0.5, PUP_DIR + f"saved_image/CT_UP{idx}.png")

        # Running mean of the MRI critic's scores on real and fake MRI.
        loop.set_postfix(M_real=M_reals / (idx + 1), M_fake=M_fakes / (idx + 1))

**Test Function**

In [None]:
def test_fn(gen_C, gen_M, loader, l1):
    """Translate every MRI in ``loader`` to CT with ``gen_C`` and save real and fake images.

    For batch ``idx`` this writes ``{idx}_MR_Real.png``, ``{idx}_CT_Real.png``
    and ``{idx}_CT_Fake.png`` under ``PUP_DIR + "saved_image/"``. ``gen_M`` and
    ``l1`` are accepted for call-site compatibility but are not used.

    Args:
        gen_C: generator mapping MRI -> CT.
        gen_M: unused.
        loader: iterable yielding ``(ct, mri)`` batches.
        l1: unused.
    """
    loop = tqdm(loader, leave=True)
    # Inference only. Without no_grad every forward pass would build an autograd
    # graph that is never used, holding activations in memory for nothing.
    with torch.no_grad():
        for idx, (ct, mri) in enumerate(loop):
            ct = ct.to(DEVICE)
            mri = mri.to(DEVICE)

            with torch.cuda.amp.autocast():
                fake_ct = gen_C(mri)
            # Images are in [-1, 1]; *0.5 + 0.5 maps them back to [0, 1] for saving.
            save_image(mri * 0.5 + 0.5, PUP_DIR + f"saved_image/{idx}_MR_Real.png")
            save_image(ct * 0.5 + 0.5, PUP_DIR + f"saved_image/{idx}_CT_Real.png")
            save_image(fake_ct * 0.5 + 0.5, PUP_DIR + f"saved_image/{idx}_CT_Fake.png")

**Main function**

In [None]:
def main():
    """Build both generator/critic pairs and train on the paired "Select" MRI/CT folders.

    The networks and optimizers are created first. If LOAD_MODEL is set, they
    are restored from the *_PUP checkpoints. Training then runs for NUM_EPOCHS
    epochs of ``train_fn`` in paired mode, and the *_PUP checkpoints are
    rewritten after each epoch when SAVE_MODEL is set.
    """
    disc_M = Discriminator(in_channels=3).to(DEVICE)
    disc_C = Discriminator(in_channels=3).to(DEVICE)
    gen_C = Generator(img_channels=3, num_residuals=9).to(DEVICE)
    gen_M = Generator(img_channels=3, num_residuals=9).to(DEVICE)

    adam_settings = {"lr": LEARNING_RATE, "betas": (0.5, 0.999)}
    opt_disc = optim.Adam([*disc_M.parameters(), *disc_C.parameters()], **adam_settings)
    opt_gen = optim.Adam([*gen_C.parameters(), *gen_M.parameters()], **adam_settings)

    L1 = nn.L1Loss()
    mse = nn.MSELoss()

    # (checkpoint file, module, optimizer), in the order they are loaded and saved.
    # NOTE(review): opt_gen and opt_disc each appear twice, so the second load
    # overwrites the first. That is harmless only while both files were saved
    # from the same optimizer state, as the save loop below does.
    checkpoints = (
        (CHECKPOINT_GEN_M_PUP, gen_M, opt_gen),
        (CHECKPOINT_GEN_C_PUP, gen_C, opt_gen),
        (CHECKPOINT_CRITIC_M_PUP, disc_M, opt_disc),
        (CHECKPOINT_CRITIC_C_PUP, disc_C, opt_disc),
    )
    if LOAD_MODEL:
        for path, module, optimizer in checkpoints:
            load_checkpoint(path, module, optimizer, LEARNING_RATE)

    # Other data configurations used in earlier runs (disabled):
    #   unpaired: MriCtDataset(root_mri=CYCLE_DIR + "MRI", root_ct=CYCLE_DIR + "CT",
    #             transform=transforms), trained with train_fn_cycle(...)
    #   paired:   MriCtDataset(root_mri=CYCLE_DIR + "MRI-Paired",
    #             root_ct=CYCLE_DIR + "CT-Paired", transform=transforms)
    datasetP = MriCtDataset(
        root_mri=CYCLE_DIR + "MRI-Select",
        root_ct=CYCLE_DIR + "CT-Select",
        transform=transforms,
    )
    loaderP = DataLoader(
        datasetP,
        batch_size=BATCH_SIZE,
        shuffle=False,
        num_workers=NUM_WORKERS,
        pin_memory=True,
    )
    g_scaler = torch.cuda.amp.GradScaler()
    d_scaler = torch.cuda.amp.GradScaler()

    for _epoch in range(NUM_EPOCHS):
        train_fn(disc_M, disc_C, gen_C, gen_M, loaderP, opt_disc, opt_gen, L1, mse, d_scaler, g_scaler, True)

        if SAVE_MODEL:
            for path, module, optimizer in checkpoints:
                save_checkpoint(module, optimizer, filename=path)

    # To dump translated samples instead, call: test_fn(gen_C, gen_M, loaderP, L1)


if __name__ == "__main__":
    main()