In [1]:
!conda activate nauman_gpu

**Requiremements**




In [None]:
# conda install pytorch=1.12.0 torchvision=0.13.0 cudatoolkit=11.3.1 -c pytorch
# conda install albumentations=1.2.1 -c conda-forge
# conda install einops=0.4.1 -c conda-forge
# conda install torchsummary=1.5.1 -c ravelbio
# conda install tqdm=4.64.0
# conda install scikit-image=0.18.3
# conda install scikit-learn=1.1.1
# conda install torchmetrics=0.9.3 -c conda-forge

# **Data Downloading**
For training and evaluation we used the BraTS Dataset from RSNA-ASNR-MICCAI BraTS Continuous Evaluation Challenge:
https://www.synapse.org/#!Synapse:syn51156910/wiki/622351

1. Download the data
2. Convert 3d .nii.gz file to 2d
    2a. We used two modalities. i) T1 ii) T1 Contrast
3. Resize the images to 256 256, (Original images in 3D are 240 240 155)
4. We have placed the covnerted 2d resized files here:
https://drive.google.com/file/d/1hXJ8CP6BgJz2bFM1OF2LmGtc3VvRgZhS/view?usp=sharing

# **Preprocessing**


In [5]:
import os
import sys
import numpy as np
import glob
import matplotlib.pyplot as plt
import nibabel as nib
import itertools
from PIL import Image

In [None]:
# Converting T1 3D Nifti images to 2D PNG images 

folder = '/local/data0/home/naumanb/Codes/4_SwinUnetR_BraTS_2023/BraTS2023/TrainingData'

for root, dir, files in os.walk(folder):
    for file in files:
        im = os.path.join(root, file)
        # print(file)

        # if file.endswith('t1c.nii.gz'):
        if file.endswith('t1n.nii.gz'):
            print(file)
            img = os.path.join(root, file)
            img = nib.load(img)
            img = img.get_data().astype(np.float32)
            img_data = np.rot90(img, 3)

            plt.imshow(img_data[:, :, 77], cmap='gray')
            plt.axis('off')
            # Save as PNG
            plt.savefig('/local/data0/home/naumanb/Codes/8_Cx_GAN/data/t1n/' +
                        file+'.png', bbox_inches='tight', pad_inches=0)

            # Show the plot (optional)
            # plt.show()

    # break

In [None]:
# Converting T1-Contrat 3D Nifti images to 2D PNG images

import os
import numpy as np
import matplotlib.pyplot as plt
from skimage.transform import resize
import nibabel as nib

folder = '/local/data0/home/naumanb/Codes/4_SwinUnetR_BraTS_2023/BraTS2023/TrainingData'

for root, dir, files in os.walk(folder):
    for file in files:
        im = os.path.join(root, file)

        if file.endswith('t1c.nii.gz'):
            print(file)
            img_path = os.path.join(root, file)
            img = nib.load(img_path)
            img_data = img.get_data().astype(np.float32)
            img_data = np.rot90(img_data, 3)


            plt.imshow(img_data[:, :, 77], cmap='gray')
            plt.axis('off')

            # Save as PNG
            output_path = '/local/data0/home/naumanb/Codes/8_Cx_GAN/data/t1c/' + file + '.png'
            plt.savefig(output_path, bbox_inches='tight', pad_inches=0)
            # plt.show()

In [None]:
# Resizing the images to 256x256

#Resize the image to 256,256
from PIL import Image
import os, sys

path = "/local/data0/home/naumanb/Codes/8_Cx_GAN/data/t1c/"
dirs = os.listdir( path )
out = "/local/data0/home/naumanb/Codes/8_Cx_GAN/data/t1c_r/"

def resize():
    for item in dirs:
        if os.path.isfile(path+item):
            im = Image.open(path+item)
            f, e = os.path.splitext(path+item)
            print(item)
            print(f)
            imResize = im.resize((256,256), Image.ANTIALIAS)
            imResize.save(out+item+'.png', 'PNG', quality=100)

resize()

# **Data Loader**

In [2]:
import os
import numpy as np
from PIL import Image
from torch.utils.data import Dataset
from PIL import UnidentifiedImageError
import PIL

class ABDataset(Dataset):
    def __init__(self, root_a, root_b=None, transform=None):
        self.root_a = root_a
        self.root_b = root_b
        self.transform = transform

        self.a_images = os.listdir(root_a)  # t1 images
        self.b_images = os.listdir(root_b) if root_b else []  # t1c images
        self.length_dataset = max(len(self.a_images), len(self.b_images))
        self.a_len = len(self.a_images)
        self.b_len = len(self.b_images)

    def __len__(self):
        return self.length_dataset

    def __getitem__(self, index):
        if self.root_b is not None:
            a_img = self.a_images[index % self.a_len]
            a_path = os.path.join(self.root_a, a_img)

            while True:
                try:
                    a_img = np.array(Image.open(a_path).convert("RGB"))
                    break
                except (PIL.UnidentifiedImageError, OSError):
                    # skip non-image files
                    index = (index + 1) % self.length_dataset
                    a_img = self.a_images[index % self.a_len]
                    a_path = os.path.join(self.root_a, a_img)

            b_img = self.b_images[index % self.b_len]
            b_path = os.path.join(self.root_b, b_img)

            while True:
                try:
                    b_img = np.array(Image.open(b_path).convert("RGB"))
                    break
                except (PIL.UnidentifiedImageError, OSError):
                    # skip non-image files
                    index = (index + 1) % self.length_dataset
                    b_img = self.b_images[index % self.b_len]
                    b_path = os.path.join(self.root_b, b_img)

            # Apply min-max normalization to images
            a_img = (a_img - np.min(a_img)) / (np.max(a_img) - np.min(a_img))
            b_img = (b_img - np.min(b_img)) / (np.max(b_img) - np.min(b_img))

            if self.transform:
                augmentations = self.transform(image0=a_img, image=b_img)
                a_img = augmentations["image0"]
                b_img = augmentations["image"]

            return a_img, b_img

        elif self.root_b is None:
            a_img = self.a_images[index % self.a_len]
            a_path = os.path.join(self.root_a, a_img)

            while True:
                try:
                    a_img = np.array(Image.open(a_path).convert("RGB"))
                    break
                except (PIL.UnidentifiedImageError, OSError):
                    # skip non-image files
                    index = (index + 1) % self.length_dataset
                    a_img = self.a_images[index % self.a_len]
                    a_path = os.path.join(self.root_a, a_img)

            # Apply min-max normalization to images
            a_img = (a_img - np.min(a_img)) / (np.max(a_img) - np.min(a_img))

            if self.transform:
                augmentations = self.transform(image=a_img)
                a_img = augmentations["image"]

            return a_img


# **Network**

In [3]:
# Generator 

import torch
from cyclegan_tranformer import Generator, Discriminator
gen = Generator(width=256, height=256).to(torch.device("cuda"))
    
print(gen)
tensor = torch.randn((1, 3, 256, 256)).to(torch.device("cuda"))
output = gen(tensor)
print(output.shape)


Generator(
  (patches): ImgPatches(
    (patch_embed): Conv2d(3, 1024, kernel_size=(8, 8), stride=(8, 8))
  )
  (TransformerEncoder): TransformerEncoder(
    (Encoder_Blocks): ModuleList(
      (0): Encoder_Block(
        (ln1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
        (attn): Attention(
          (qkv): Linear(in_features=1024, out_features=3072, bias=False)
          (attention_dropout): Dropout(p=0.0, inplace=False)
          (out): Sequential(
            (0): Linear(in_features=1024, out_features=1024, bias=True)
            (1): Dropout(p=0.0, inplace=False)
          )
        )
        (ln2): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
        (mlp): MLP(
          (fc1): Linear(in_features=1024, out_features=4096, bias=True)
          (act): GELU(approximate='none')
          (fc2): Linear(in_features=4096, out_features=1024, bias=True)
          (droprateout): Dropout(p=0.0, inplace=False)
        )
      )
    )
  )
  (up_blocks): ModuleList(

In [4]:
# Discriminator

import torch
from cyclegan_tranformer import Generator, Discriminator

disc = Discriminator().to(torch.device("cuda"))
tensor = torch.randn((1, 3, 256, 256)).to(torch.device("cuda"))
gen_output = gen(tensor)
print("Generator output shape:", gen_output.shape)
print(disc)
disc_output = disc(gen_output)
print("Discriminator output shape:", disc_output.shape)

Generator output shape: torch.Size([1, 3, 256, 256])
Discriminator(
  (initial): Sequential(
    (0): Conv2d(3, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), padding_mode=reflect)
    (1): LeakyReLU(negative_slope=0.2, inplace=True)
  )
  (down_blocks): ModuleList(
    (0): ConvolutionBlockD(
      (convolution): Sequential(
        (0): Conv2d(64, 128, kernel_size=(4, 4), stride=(1, 1), padding=(1, 1), padding_mode=reflect)
        (1): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
        (2): LeakyReLU(negative_slope=0.2, inplace=True)
      )
    )
    (1): ConvolutionBlockD(
      (convolution): Sequential(
        (0): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), padding_mode=reflect)
        (1): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
        (2): LeakyReLU(negative_slope=0.2, inplace=True)
      )
    )
    (2): ConvolutionBlockD(
      (convolution): Sequentia

# **Training**

In [1]:
import os

import torch
from torch import nn
from torch import optim
from torch.utils.data import DataLoader
from torchvision.utils import save_image
import albumentations as A
from albumentations.pytorch import ToTensorV2
from tqdm import tqdm
# from torchsummary import summary
from Utils import save_checkpoint, load_checkpoint
from cyclegan_tranformer import Generator, Discriminator
from torch.utils.tensorboard import SummaryWriter 
import gc
gc.collect()
torch.cuda.empty_cache()

writer = SummaryWriter(log_dir="logs")

TRAIN_DIR = "datasets/brats/train"
path = "Results"
BATCH_SIZE = 8
LEARNING_RATE = 1e-4
LAMBDA_IDENTITY = 10
LAMBDA_CYCLE = 10
NUM_WORKERS = 4
NUM_EPOCHS = 500
LOAD_MODEL = False
SAVE_MODEL = True
IMAGE_HEIGHT = 256
IMAGE_WIDTH = 256
CHECKPOINT_GEN_A = f"{path}/gena.pth.tar"
CHECKPOINT_GEN_B = f"{path}/genb.pth.tar"
CHECKPOINT_DISC_A = f"{path}/disca.pth.tar"
CHECKPOINT_DISC_B = f"{path}/discb.pth.tar"
count = 0


gpu_index = 2  # for 4th GPU


if torch.cuda.is_available():
    DEVICE = torch.device(f"cuda:{gpu_index}")
    print(f"Selected GPU: {gpu_index}")
else:
    DEVICE = torch.device("cpu")
    print("CUDA is not available. Using CPU.")





if not os.path.exists("Results"):
    os.mkdir("Results")
    os.mkdir("Results/Generated from T1C")
    os.mkdir("Results/Generated from T1")

transforms = A.Compose(
    [
        A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH),
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.5),
        A.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], max_pixel_value=255),
        ToTensorV2(),
        ],
    additional_targets={"image0": "image"},
)


def train_fn(disc_A, disc_B, gen_A, gen_B, loader, opt_disc, opt_gen, l1, mse,  d_scaler, g_scaler, epoch):
    global count
    avg_dloss = 0
    avg_gloss = 0
    loop = tqdm(loader, leave=True)
    for idx, (a, b) in enumerate(loop):
        a = a.to(DEVICE)
        b = b.to(DEVICE)

        with torch.cuda.amp.autocast():
            fake_a = gen_A(b)
            D_A_real = disc_A(a)
            D_A_fake = disc_A(fake_a.detach())
            D_A_real_loss = mse(D_A_real, torch.ones_like(D_A_real))
            D_A_fake_loss = mse(D_A_fake, torch.zeros_like(D_A_fake))
            D_A_loss = D_A_real_loss + D_A_fake_loss

            fake_b = gen_B(a)
            D_B_real = disc_B(b)
            D_B_fake = disc_B(fake_b.detach())
            D_B_real_loss = mse(D_B_real, torch.ones_like(D_B_real))
            D_B_fake_loss = mse(D_B_fake, torch.zeros_like(D_B_fake))
            D_B_loss = D_B_real_loss + D_B_fake_loss

            D_loss = (D_A_loss + D_B_loss)/2

        opt_disc.zero_grad()
        d_scaler.scale(D_loss).backward()
        d_scaler.step(opt_disc)
        d_scaler.update()

        with torch.cuda.amp.autocast():
            D_A_fake = disc_A(fake_a)
            D_B_fake = disc_B(fake_b)
            loss_G_A = mse(D_A_fake, torch.ones_like(D_A_fake))
            loss_G_B = mse(D_B_fake, torch.ones_like(D_B_fake))

            cycle_b = gen_B(fake_a)
            cycle_a = gen_A(fake_b)
            cycle_b_loss = l1(b, cycle_b)
            cycle_a_loss = l1(a, cycle_a)

            identity_b = gen_B(b)
            identity_a = gen_A(a)
            identity_b_loss = l1(b, identity_b)
            identity_a_loss = l1(a, identity_a)

            G_loss = (
                loss_G_B
                + loss_G_A
                + cycle_b_loss * LAMBDA_CYCLE
                + cycle_a_loss * LAMBDA_CYCLE
                + identity_a_loss * LAMBDA_IDENTITY
                + identity_b_loss * LAMBDA_IDENTITY
            )

            avg_dloss += D_loss.item()
            avg_gloss += G_loss.item()

        opt_gen.zero_grad()
        g_scaler.scale(G_loss).backward()
        g_scaler.step(opt_gen)
        g_scaler.update()
        
        writer.add_scalar("Discriminator Loss", D_loss.item(), epoch * len(loader) + idx)
        writer.add_scalar("Generator Loss", G_loss.item(), epoch * len(loader) + idx)

        if idx % 100 == 0:
            save_image(fake_a*0.5+0.5, f"{path}/Generated from T1c/{count}_fake.png")
            save_image(fake_b*0.5+0.5, f"{path}/Generated from T1/{count}_fake.png")
            save_image(b*0.5+0.5, f"{path}/Generated from T1c/{count}_real.png")
            save_image(a*0.5+0.5, f"{path}/Generated from T1/{count}_real.png")
            count += 1
        loop.set_postfix(epoch=epoch+1, loss_g=avg_gloss/(idx+1), loss_d=avg_dloss/(idx+1))


def main():
    disc_A = Discriminator().to(DEVICE)
    disc_B = Discriminator().to(DEVICE)
    gen_A = Generator(width=IMAGE_WIDTH, height=IMAGE_HEIGHT).to(DEVICE)
    gen_B = Generator(width=IMAGE_WIDTH, height=IMAGE_HEIGHT).to(DEVICE)

    opt_disc = optim.Adam(
        list(disc_A.parameters()) + list(disc_B.parameters()),
        lr=LEARNING_RATE,
        betas=(0.5, 0.999),
    )

    opt_gen = optim.Adam(
        list(gen_A.parameters()) + list(gen_B.parameters()),
        lr=LEARNING_RATE,
        betas=(0.5, 0.999),
    )

    L1 = nn.L1Loss()
    mse = nn.MSELoss()

    if LOAD_MODEL:
        load_checkpoint(
            CHECKPOINT_GEN_A, gen_A, opt_gen, LEARNING_RATE,
        )
        load_checkpoint(
            CHECKPOINT_GEN_B, gen_B, opt_gen, LEARNING_RATE,
        )
        load_checkpoint(
            CHECKPOINT_DISC_A, disc_A, opt_disc, LEARNING_RATE,
        )
        load_checkpoint(
            CHECKPOINT_DISC_B, disc_B, opt_disc, LEARNING_RATE,
        )

    dataset = ABDataset(
        root_a=TRAIN_DIR+"/t1c_r", root_b=TRAIN_DIR+"/t1n_r", transform=transforms
    )

    loader = DataLoader(
        dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=NUM_WORKERS,
        pin_memory=True
    )

    g_scaler = torch.cuda.amp.GradScaler()
    d_scaler = torch.cuda.amp.GradScaler()

    for epoch in range(NUM_EPOCHS):
        train_fn(disc_A, disc_B, gen_A, gen_B, loader, opt_disc, opt_gen, L1, mse, d_scaler, g_scaler, epoch)

        if SAVE_MODEL:
            save_checkpoint(gen_A, opt_gen, filename=CHECKPOINT_GEN_A)
            save_checkpoint(gen_B, opt_gen, filename=CHECKPOINT_GEN_B)
            save_checkpoint(disc_A, opt_disc, filename=CHECKPOINT_DISC_A)
            save_checkpoint(disc_B, opt_disc, filename=CHECKPOINT_DISC_B)
            
    writer.close()


2023-12-19 11:43:42.758876: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcudart.so.10.1


Selected GPU: 3


In [6]:
import os

import torch
from torch import nn
from torch import optim
from torch.utils.data import DataLoader
from torchvision.utils import save_image
import albumentations as A
from albumentations.pytorch import ToTensorV2
from tqdm import tqdm
# from torchsummary import summary
from Utils import save_checkpoint, load_checkpoint
from cyclegan_tranformer import Generator, Discriminator
from torch.utils.tensorboard import SummaryWriter 
import gc
gc.collect()
torch.cuda.empty_cache()

writer = SummaryWriter(log_dir="logs")

TRAIN_DIR = "datasets/brats/train"
path = "Results"
BATCH_SIZE = 8
LEARNING_RATE = 1e-4
LAMBDA_IDENTITY = 10
LAMBDA_CYCLE = 10
NUM_WORKERS = 4
NUM_EPOCHS = 500
LOAD_MODEL = False
SAVE_MODEL = True
IMAGE_HEIGHT = 256
IMAGE_WIDTH = 256
CHECKPOINT_GEN_A = f"{path}/gena.pth.tar"
CHECKPOINT_GEN_B = f"{path}/genb.pth.tar"
CHECKPOINT_DISC_A = f"{path}/disca.pth.tar"
CHECKPOINT_DISC_B = f"{path}/discb.pth.tar"
count = 0


gpu_index = 2  # for 4th GPU


if torch.cuda.is_available():
    DEVICE = torch.device(f"cuda:{gpu_index}")
    print(f"Selected GPU: {gpu_index}")
else:
    DEVICE = torch.device("cpu")
    print("CUDA is not available. Using CPU.")





if not os.path.exists("Results"):
    os.mkdir("Results")
    os.mkdir("Results/Generated from T1C")
    os.mkdir("Results/Generated from T1")

transforms = A.Compose(
    [
        A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH),
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.5),
        A.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], max_pixel_value=255),
        ToTensorV2(),
        ],
    additional_targets={"image0": "image"},
)


def train_fn(disc_A, disc_B, gen_A, gen_B, loader, opt_disc, opt_gen, l1, mse,  d_scaler, g_scaler, epoch):
    global count
    avg_dloss = 0
    avg_gloss = 0
    loop = tqdm(loader, leave=True)
    for idx, (a, b) in enumerate(loop):
        a = a.to(DEVICE)
        b = b.to(DEVICE)

        with torch.cuda.amp.autocast():
            fake_a = gen_A(b)
            D_A_real = disc_A(a)
            D_A_fake = disc_A(fake_a.detach())
            D_A_real_loss = mse(D_A_real, torch.ones_like(D_A_real))
            D_A_fake_loss = mse(D_A_fake, torch.zeros_like(D_A_fake))
            D_A_loss = D_A_real_loss + D_A_fake_loss

            fake_b = gen_B(a)
            D_B_real = disc_B(b)
            D_B_fake = disc_B(fake_b.detach())
            D_B_real_loss = mse(D_B_real, torch.ones_like(D_B_real))
            D_B_fake_loss = mse(D_B_fake, torch.zeros_like(D_B_fake))
            D_B_loss = D_B_real_loss + D_B_fake_loss

            D_loss = (D_A_loss + D_B_loss)/2

        opt_disc.zero_grad()
        d_scaler.scale(D_loss).backward()
        d_scaler.step(opt_disc)
        d_scaler.update()

        with torch.cuda.amp.autocast():
            D_A_fake = disc_A(fake_a)
            D_B_fake = disc_B(fake_b)
            loss_G_A = mse(D_A_fake, torch.ones_like(D_A_fake))
            loss_G_B = mse(D_B_fake, torch.ones_like(D_B_fake))

            cycle_b = gen_B(fake_a)
            cycle_a = gen_A(fake_b)
            cycle_b_loss = l1(b, cycle_b)
            cycle_a_loss = l1(a, cycle_a)

            identity_b = gen_B(b)
            identity_a = gen_A(a)
            identity_b_loss = l1(b, identity_b)
            identity_a_loss = l1(a, identity_a)

            G_loss = (
                loss_G_B
                + loss_G_A
                + cycle_b_loss * LAMBDA_CYCLE
                + cycle_a_loss * LAMBDA_CYCLE
                + identity_a_loss * LAMBDA_IDENTITY
                + identity_b_loss * LAMBDA_IDENTITY
            )

            avg_dloss += D_loss.item()
            avg_gloss += G_loss.item()

        opt_gen.zero_grad()
        g_scaler.scale(G_loss).backward()
        g_scaler.step(opt_gen)
        g_scaler.update()
        
        writer.add_scalar("Discriminator Loss", D_loss.item(), epoch * len(loader) + idx)
        writer.add_scalar("Generator Loss", G_loss.item(), epoch * len(loader) + idx)

        if idx % 100 == 0:
            save_image(fake_a*0.5+0.5, f"{path}/Generated from T1c/{count}_fake.png")
            save_image(fake_b*0.5+0.5, f"{path}/Generated from T1/{count}_fake.png")
            save_image(b*0.5+0.5, f"{path}/Generated from T1c/{count}_real.png")
            save_image(a*0.5+0.5, f"{path}/Generated from T1/{count}_real.png")
            count += 1
        loop.set_postfix(epoch=epoch+1, loss_g=avg_gloss/(idx+1), loss_d=avg_dloss/(idx+1))


def main():
    disc_A = Discriminator().to(DEVICE)
    disc_B = Discriminator().to(DEVICE)
    gen_A = Generator(width=IMAGE_WIDTH, height=IMAGE_HEIGHT).to(DEVICE)
    gen_B = Generator(width=IMAGE_WIDTH, height=IMAGE_HEIGHT).to(DEVICE)

    opt_disc = optim.Adam(
        list(disc_A.parameters()) + list(disc_B.parameters()),
        lr=LEARNING_RATE,
        betas=(0.5, 0.999),
    )

    opt_gen = optim.Adam(
        list(gen_A.parameters()) + list(gen_B.parameters()),
        lr=LEARNING_RATE,
        betas=(0.5, 0.999),
    )

    L1 = nn.L1Loss()
    mse = nn.MSELoss()

    if LOAD_MODEL:
        load_checkpoint(
            CHECKPOINT_GEN_A, gen_A, opt_gen, LEARNING_RATE,
        )
        load_checkpoint(
            CHECKPOINT_GEN_B, gen_B, opt_gen, LEARNING_RATE,
        )
        load_checkpoint(
            CHECKPOINT_DISC_A, disc_A, opt_disc, LEARNING_RATE,
        )
        load_checkpoint(
            CHECKPOINT_DISC_B, disc_B, opt_disc, LEARNING_RATE,
        )

    dataset = ABDataset(
        root_a=TRAIN_DIR+"/t1c_r", root_b=TRAIN_DIR+"/t1n_r", transform=transforms
    )

    loader = DataLoader(
        dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=NUM_WORKERS,
        pin_memory=True
    )

    g_scaler = torch.cuda.amp.GradScaler()
    d_scaler = torch.cuda.amp.GradScaler()

    for epoch in range(NUM_EPOCHS):
        train_fn(disc_A, disc_B, gen_A, gen_B, loader, opt_disc, opt_gen, L1, mse, d_scaler, g_scaler, epoch)

        if SAVE_MODEL:
            save_checkpoint(gen_A, opt_gen, filename=CHECKPOINT_GEN_A)
            save_checkpoint(gen_B, opt_gen, filename=CHECKPOINT_GEN_B)
            save_checkpoint(disc_A, opt_disc, filename=CHECKPOINT_DISC_A)
            save_checkpoint(disc_B, opt_disc, filename=CHECKPOINT_DISC_B)
            
    writer.close()


Selected GPU: 2


In [8]:
# Train Normalization

main()


100%|██████████| 126/126 [01:12<00:00,  1.73it/s, epoch=1, loss_d=0.232, loss_g=3.27]
100%|██████████| 126/126 [01:16<00:00,  1.65it/s, epoch=2, loss_d=0.11, loss_g=1.54] 
100%|██████████| 126/126 [01:17<00:00,  1.63it/s, epoch=3, loss_d=0.0485, loss_g=1.68]
100%|██████████| 126/126 [01:17<00:00,  1.63it/s, epoch=4, loss_d=0.0695, loss_g=1.75]
100%|██████████| 126/126 [01:16<00:00,  1.64it/s, epoch=5, loss_d=0.0617, loss_g=1.65]
100%|██████████| 126/126 [01:15<00:00,  1.66it/s, epoch=6, loss_d=0.143, loss_g=1.19]
100%|██████████| 126/126 [01:15<00:00,  1.67it/s, epoch=7, loss_d=0.237, loss_g=0.572]
100%|██████████| 126/126 [01:15<00:00,  1.67it/s, epoch=8, loss_d=0.232, loss_g=0.584]
100%|██████████| 126/126 [01:15<00:00,  1.67it/s, epoch=9, loss_d=0.227, loss_g=0.594]
100%|██████████| 126/126 [01:15<00:00,  1.67it/s, epoch=10, loss_d=0.223, loss_g=0.603]
100%|██████████| 126/126 [01:15<00:00,  1.67it/s, epoch=11, loss_d=0.219, loss_g=0.611]
100%|██████████| 126/126 [01:15<00:00,  1.67

In [4]:

main()


100%|██████████| 126/126 [01:13<00:00,  1.73it/s, epoch=1, loss_d=0.367, loss_g=4.2] 
100%|██████████| 126/126 [01:15<00:00,  1.66it/s, epoch=2, loss_d=0.377, loss_g=2.6] 
100%|██████████| 126/126 [01:16<00:00,  1.65it/s, epoch=3, loss_d=0.377, loss_g=2.58]
100%|██████████| 126/126 [01:16<00:00,  1.65it/s, epoch=4, loss_d=0.436, loss_g=2.2] 
100%|██████████| 126/126 [01:16<00:00,  1.65it/s, epoch=5, loss_d=0.447, loss_g=1.95]
100%|██████████| 126/126 [01:16<00:00,  1.65it/s, epoch=6, loss_d=0.443, loss_g=1.88]
100%|██████████| 126/126 [01:16<00:00,  1.65it/s, epoch=7, loss_d=0.433, loss_g=1.9] 
100%|██████████| 126/126 [01:16<00:00,  1.65it/s, epoch=8, loss_d=0.443, loss_g=1.85]
100%|██████████| 126/126 [01:16<00:00,  1.65it/s, epoch=9, loss_d=0.442, loss_g=1.76]
100%|██████████| 126/126 [01:16<00:00,  1.65it/s, epoch=10, loss_d=0.451, loss_g=1.7] 
100%|██████████| 126/126 [01:16<00:00,  1.65it/s, epoch=11, loss_d=0.417, loss_g=1.99]
100%|██████████| 126/126 [01:16<00:00,  1.65it/s, ep

**Run Training**

In [15]:

main()


100%|██████████| 126/126 [01:12<00:00,  1.75it/s, epoch=1, loss_d=0.384, loss_g=4.77]
100%|██████████| 126/126 [01:15<00:00,  1.66it/s, epoch=2, loss_d=0.37, loss_g=2.67] 
100%|██████████| 126/126 [01:16<00:00,  1.65it/s, epoch=3, loss_d=0.439, loss_g=2.29]
100%|██████████| 126/126 [01:16<00:00,  1.65it/s, epoch=4, loss_d=0.435, loss_g=2.11]
100%|██████████| 126/126 [01:16<00:00,  1.65it/s, epoch=5, loss_d=0.445, loss_g=1.96]
100%|██████████| 126/126 [01:16<00:00,  1.65it/s, epoch=6, loss_d=0.432, loss_g=1.95]
100%|██████████| 126/126 [01:16<00:00,  1.65it/s, epoch=7, loss_d=0.439, loss_g=1.8] 
100%|██████████| 126/126 [01:16<00:00,  1.65it/s, epoch=8, loss_d=0.438, loss_g=1.79]
100%|██████████| 126/126 [01:16<00:00,  1.65it/s, epoch=9, loss_d=0.437, loss_g=1.79]
100%|██████████| 126/126 [01:16<00:00,  1.65it/s, epoch=10, loss_d=0.386, loss_g=2.13]
100%|██████████| 126/126 [01:16<00:00,  1.65it/s, epoch=11, loss_d=0.402, loss_g=1.92]
100%|██████████| 126/126 [01:16<00:00,  1.65it/s, ep

# **Test**

In [None]:
import time
import gc
import torch
import math
import albumentations as A
import numpy as np

from albumentations.pytorch import ToTensorV2
from torch.utils.data import DataLoader
from torchvision.utils import save_image
from tqdm import tqdm
# from torchsummary import summary
# from torchmetrics import PeakSignalNoiseRatio
from skimage import color
from skimage.metrics import structural_similarity
from cyclegan_tranformer import Generator
from Utils import save_checkpoint, load_checkpoint

# gc.collect()
# torch.cuda.empty_cache()

dataset_name = "brats"  

path = "Results"
checkpoint = "Results/genb.pth.tar"
save_path = f"Results/Testing/{dataset_name}"
TEST_DIR = 'datasets/brats/test/'
IMAGE_WIDTH = 256
IMAGE_HEIGHT = 256


gpu_index = 3  # for 4th GPU

if torch.cuda.is_available():
    DEVICE = torch.device(f"cuda:{gpu_index}")
    print(f"Selected GPU: {gpu_index}")
else:
    DEVICE = torch.device("cpu")
    print("CUDA is not available. Using CPU.")



def masking(a, b):
    l_top = l_bottom = 0
    a = a[0]
    b = b[0]

    for i in range(a.shape[1]):
        if torch.sum(a[:, i, :]) != 0:
            break
        l_top += 1

    for i in range(a.shape[1]):
        if torch.sum(a[:, a.shape[1] - i - 1, :]) != 0:
            break
        l_bottom += 1

    b[:, :l_top, :] = 0
    b[:, b.shape[1] - l_bottom:, :] = 0

    return a, b


def PSNR_SSIM(orig_img, gen_img):
    gray_orig_img = color.rgb2gray(orig_img)
    gray_gen_img = color.rgb2gray(gen_img)

    mse = np.mean((gray_orig_img - gray_gen_img) ** 2)
    if mse == 0:
        psnr = 100
    else:
        max_pixel = 1.0
        psnr = 20 * math.log10(max_pixel / math.sqrt(mse))

    ssim = structural_similarity(gray_orig_img, gray_gen_img, multichannel=False, data_range=1.0)

    return round(psnr, 3), round(ssim, 3)



gen = Generator(width=IMAGE_WIDTH, height=IMAGE_HEIGHT).to(DEVICE)
# summary(gen, (3, 256, 256))

load_checkpoint(checkpoint, gen, None, None)
print("Checkpoint loaded")

transforms = A.Compose(
    [
        A.Resize(width=IMAGE_WIDTH, height=IMAGE_HEIGHT),
        A.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], max_pixel_value=255),
        ToTensorV2(),
    ]
)

val_dataset = ABDataset(
    # root_a=TEST_DIR, transform=transforms
    root_a=TEST_DIR+"/t1c_r", transform=transforms
)

val_loader = DataLoader(
    val_dataset,
    batch_size=1,
    shuffle=False,
    pin_memory=True,
)

loop = tqdm(val_loader, leave=True)
psnr_values = []
ssim_values = []
mse_values = []

print("val_loader", len(val_loader))

start = time.time()

# print("Before loop")
for idx, image in enumerate(loop):
    image = image.to(DEVICE)
    # print("Inside loop1")
    with torch.cuda.amp.autocast():
        # print("Inside loop2")
        # print("idx", idx)
        
        gen_image = gen(image)
        image, gen_image = masking(image*0.5+0.5, gen_image*0.5+0.5)

        save_image(gen_image, f"{save_path}/{idx}_fake.png")
        save_image(image, f"{save_path}/{idx}_real.png")

        image = image.permute(1, 2, 0).detach().cpu().numpy()
        gen_image = gen_image.permute(1, 2, 0).detach().cpu().numpy()

        psnr_values.append(PSNR_SSIM(image, gen_image)[0])
        ssim_values.append(PSNR_SSIM(image, gen_image)[1])
        
        # mse = torch.mean((gen_image - image)**2).item()
        image_tensor = torch.from_numpy(image).permute(2, 0, 1).to(DEVICE)
        gen_image_tensor = torch.from_numpy(gen_image).permute(2, 0, 1).to(DEVICE)
        mse = torch.mean((gen_image_tensor - image_tensor)**2).item()
        mse_values.append(mse)

end = time.time()

metrics = [
    round(sum(psnr_values) / len(val_loader), 3),
    round(sum(ssim_values) / len(val_loader), 3),
    round((end - start) / len(val_loader), 3)
]

f = open(f"{path}/Results {dataset_name}.txt", 'w')
f.write(f"Testing PSNR :{metrics[0]} dB\n")
f.write(f"Testing SSIM :{metrics[1]}\n")
f.write(f"Single image time: {metrics[2]} seconds\n")

print("Testing PSNR" ,metrics[0])
print("Testing SSIM" ,metrics[1])
mean_mse = round(sum(mse_values) / len(mse_values), 3)
print(f"Mean MSE: {mean_mse}")
