In [None]:
from google.colab import drive
# Mount Google Drive at /content/drive; the dataset and checkpoint paths used
# later in this notebook all live under /content/drive/MyDrive.
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
#@title REQ
import os
import sys
import cv2
import pathlib
from torch.utils.data import DataLoader
import numpy as np
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
import torchvision
import torchvision.transforms as transforms
import torchvision.datasets as dataset
import torchvision.utils as vutils
import matplotlib.pyplot as plt

# Set random seed for reproducibility
# (this seeds PyTorch's global RNG only)
torch.manual_seed(42)


# Define Variables

BATCH_SIZE = 128            # Batch size for the DataLoaders and number of visualization noise vectors
IMAGE_CHANNEL = 3           # Channels produced by the generator / consumed by the discriminator
Z_DIM = 100                 # Random vector dimension
G_HIDDEN = 64               # Base channel width of the generator's layers
X_DIM = 64                  # The width/height of the generated images. (Rebound to a (64, 64) tuple in the data cells.)
D_HIDDEN = 64               # Base channel width of the discriminator's layers
EPOCH_NUM = 50000           # Not referenced in the cells visible here
REAL_LABEL = 1              # Not referenced in the cells visible here
FAKE_LABEL = 0              # Not referenced in the cells visible here
lr = 2e-4                   # Learning rate passed to both Adam optimizers
ngpu = 1                   # Number of GPUs available. Use 0 for CPU mode


# Use the first CUDA device when one is available and ngpu > 0; otherwise fall back to CPU.
device = torch.device("cuda:0" if (torch.cuda.is_available() and ngpu > 0) else "cpu")


In [None]:
#@title Lateral Data
import os
import torch
from torchvision import transforms
from torch.utils.data import Dataset
from PIL import Image
import numpy as np

class OODImageDataset(Dataset):
    """Images from the ``LATERAL`` sub-folder of ``root``.

    Every entry of ``root/LATERAL`` whose name does not start with
    ``.DS_Store`` becomes one sample. Samples are stored in sorted filename
    order so indices are reproducible across runs. The label of every sample
    is the folder name string (``'LATERAL'``).

    Args:
        root: Directory that contains one sub-directory per class.
        transform: Optional callable applied to the loaded RGB PIL image.
        target_transform: Optional callable applied to the label.
        target_class: Name of the sub-folder to load; only ``'LATERAL'`` is
            accepted.

    Raises:
        ValueError: If ``target_class`` is not ``'LATERAL'``.
        FileNotFoundError: If ``root`` does not exist (raised by ``os.listdir``).
    """

    def __init__(self, root, transform=None, target_transform=None, target_class='LATERAL'):
        self.root = root
        self.transform = transform
        self.target_transform = target_transform
        self.samples = []  # List of (image path, class label) tuples

        # Ensure that the target_class is valid
        if target_class not in ['LATERAL']:
            raise ValueError("target_class must be 'LATERAL'.")

        # Exact-name match against the directory listing (also surfaces a
        # missing root as FileNotFoundError instead of an empty dataset).
        if target_class in os.listdir(root):
            class_dir = os.path.join(root, target_class)
            if os.path.isdir(class_dir):
                # sorted(): os.listdir order is arbitrary; keep indices stable.
                for img_name in sorted(os.listdir(class_dir)):
                    if not img_name.startswith('.DS_Store'):  # Exclude .DS_Store files
                        self.samples.append((os.path.join(class_dir, img_name), target_class))

    def __len__(self):
        """Return the number of collected samples."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx``.

        The file is decoded and converted to 3-channel RGB. A single-channel
        image gets its channel replicated three times; a file that is already
        RGB is passed through instead of failing.
        """
        img_path, class_label = self.samples[idx]

        # convert() forces the pixel data to load, so the file handle can be
        # released as soon as the with-block ends.
        with Image.open(img_path) as img:
            image = img.convert('RGB')

        if self.transform is not None:
            image = self.transform(image)
        if self.target_transform is not None:
            class_label = self.target_transform(class_label)

        return image, class_label

# Define the path to your chest X-ray dataset
DATA_PATH = "/content/drive/MyDrive/HKU/Projects/GAN/chest_xray/chest_xray/train"
# NOTE: rebinds X_DIM (an int in the config cell) to the size tuple passed to Resize below.
X_DIM = (64,64)  # You can adjust this to your desired image size
BATCH_SIZE = 128  # You can adjust this batch size as needed

# Data preprocessing: resize to X_DIM, convert to a tensor, then normalize
# with mean 0.5 and std 0.5.
transform = transforms.Compose([
    transforms.Resize(X_DIM),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

# Create the dataset and a shuffled DataLoader for the 'LATERAL' class
dataset_ood = OODImageDataset(root=DATA_PATH, transform=transform, target_class='LATERAL')
dataloader_ood = torch.utils.data.DataLoader(dataset_ood, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)


In [None]:
#@title Frontal Data
import os
import torch
from torchvision import transforms
from torch.utils.data import Dataset
from PIL import Image
import numpy as np

class ImageDataset(Dataset):
    """Images from one class sub-folder of ``root``.

    Every entry of ``root/<target_class>`` whose name does not start with
    ``.DS_Store`` becomes one sample. Samples are stored in sorted filename
    order so indices are reproducible across runs. The label of every sample
    is the folder name string.

    Args:
        root: Directory that contains one sub-directory per class.
        transform: Optional callable applied to the loaded RGB PIL image.
        target_transform: Optional callable applied to the label.
        target_class: Sub-folder to load: ``'NORMAL'``, ``'PNEUMONIA'`` or
            ``'LATERAL'``. Defaults to ``'NORMAL'`` so that the default is one
            of the accepted values (``'Frontal'`` is not accepted).

    Raises:
        ValueError: If ``target_class`` is not one of the accepted names.
        FileNotFoundError: If ``root`` does not exist (raised by ``os.listdir``).
    """

    def __init__(self, root, transform=None, target_transform=None, target_class='NORMAL'):
        self.root = root
        self.transform = transform
        self.target_transform = target_transform
        self.samples = []  # List of (image path, class label) tuples

        # Ensure that the target_class is valid
        if target_class not in ['NORMAL', 'PNEUMONIA', 'LATERAL']:
            raise ValueError("target_class must be either 'NORMAL' or 'PNEUMONIA' or 'LATERAL'.")

        # Exact-name match against the directory listing (also surfaces a
        # missing root as FileNotFoundError instead of an empty dataset).
        if target_class in os.listdir(root):
            class_dir = os.path.join(root, target_class)
            if os.path.isdir(class_dir):
                # sorted(): os.listdir order is arbitrary; keep indices stable.
                for img_name in sorted(os.listdir(class_dir)):
                    if not img_name.startswith('.DS_Store'):  # Exclude .DS_Store files
                        self.samples.append((os.path.join(class_dir, img_name), target_class))

    def __len__(self):
        """Return the number of collected samples."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx``.

        The file is decoded and converted to 3-channel RGB. A single-channel
        image gets its channel replicated three times; a file that is already
        RGB is passed through instead of failing.
        """
        img_path, class_label = self.samples[idx]

        # convert() forces the pixel data to load, so the file handle can be
        # released as soon as the with-block ends.
        with Image.open(img_path) as img:
            image = img.convert('RGB')

        if self.transform is not None:
            image = self.transform(image)
        if self.target_transform is not None:
            class_label = self.target_transform(class_label)

        return image, class_label

# Define the path to your chest X-ray dataset
DATA_PATH = "/content/drive/MyDrive/HKU/Projects/GAN/chest_xray/chest_xray/train"
# NOTE: rebinds X_DIM (an int in the config cell) to the size tuple passed to Resize below.
X_DIM = (64, 64)  # You can adjust this to your desired image size
BATCH_SIZE = 128  # You can adjust this batch size as needed

# Data preprocessing: resize to X_DIM, convert to a tensor, then normalize
# with mean 0.5 and std 0.5.
transform = transforms.Compose([
    transforms.Resize(X_DIM),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

# Create the ImageDataset for the 'normal' class
# NOTE: this rebinds `dataset`, which the imports cell uses as the alias for
# torchvision.datasets; that module is no longer reachable under this name.
dataset = ImageDataset(root=DATA_PATH, transform=transform, target_class='NORMAL')

# Dataloader
dataloader = torch.utils.data.DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)


In [None]:
#@title GAN
def weights_init(m):
    """Re-initialise a single module in place, keyed on its class name.

    Modules whose class name contains ``Conv`` get weights drawn from
    N(0.0, 0.02). Modules whose class name contains ``BatchNorm`` get weights
    drawn from N(1.0, 0.02) and a zeroed bias. Anything else is left as is.
    Intended to be passed to ``Module.apply``.
    """
    layer_kind = type(m).__name__

    if 'Conv' in layer_kind:
        m.weight.data.normal_(0.0, 0.02)
        return

    if 'BatchNorm' in layer_kind:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)

class Generator(nn.Module):
    """Transposed-convolution generator.

    Takes a tensor with ``Z_DIM`` channels and produces one with
    ``IMAGE_CHANNEL`` channels squashed by ``Tanh``. Layers are assembled in
    the same order as a hand-written ``nn.Sequential`` so that the sub-module
    indices (and therefore state-dict keys) stay ``main.0`` ... ``main.13``.
    """

    def __init__(self):
        super().__init__()

        # Output widths of the four ConvTranspose -> BatchNorm -> ReLU stages.
        stage_widths = (G_HIDDEN * 8, G_HIDDEN * 4, G_HIDDEN * 2, G_HIDDEN)

        layers = []
        in_channels = Z_DIM
        for stage_idx, out_channels in enumerate(stage_widths):
            # The input stage uses stride 1 / padding 0; later stages use stride 2 / padding 1.
            stride, padding = (1, 0) if stage_idx == 0 else (2, 1)
            layers.append(nn.ConvTranspose2d(in_channels, out_channels, 4, stride, padding, bias=False))
            layers.append(nn.BatchNorm2d(out_channels))
            layers.append(nn.ReLU(True))
            in_channels = out_channels

        # Output layer: project to image channels and squash with tanh.
        layers.append(nn.ConvTranspose2d(in_channels, IMAGE_CHANNEL, 4, 2, 1, bias=False))
        layers.append(nn.Tanh())

        self.main = nn.Sequential(*layers)

    def forward(self, input):
        """Run ``input`` through the layer stack and return the result."""
        generated = self.main(input)
        return generated

class Discriminator(nn.Module):
    """Strided-convolution discriminator ending in a sigmoid.

    Takes a tensor with ``IMAGE_CHANNEL`` channels and returns a 1-D tensor of
    sigmoid outputs. Layers are assembled in the same order as a hand-written
    ``nn.Sequential`` so that the sub-module indices (and therefore
    state-dict keys) are unchanged.
    """

    def __init__(self):
        super().__init__()

        # 1st layer: no batch norm.
        layers = [
            nn.Conv2d(IMAGE_CHANNEL, D_HIDDEN, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
        ]

        # 2nd-4th layers: Conv -> BatchNorm -> LeakyReLU with widening channels.
        in_channels = D_HIDDEN
        for multiplier in (2, 4, 8):
            out_channels = D_HIDDEN * multiplier
            layers.extend([
                nn.Conv2d(in_channels, out_channels, 4, 2, 1, bias=False),
                nn.BatchNorm2d(out_channels),
                nn.LeakyReLU(0.2, inplace=True),
            ])
            in_channels = out_channels

        # Output layer: single-channel map passed through a sigmoid.
        layers.extend([
            nn.Conv2d(in_channels, 1, 4, 1, 0, bias=False),
            nn.Sigmoid(),
        ])

        self.main = nn.Sequential(*layers)

    def forward(self, input):
        """Return the sigmoid outputs for ``input`` flattened to one dimension."""
        scores = self.main(input)
        as_column = scores.view(-1, 1)
        return as_column.squeeze(1)


# Create the generator
# Module.apply calls weights_init on every sub-module, replacing the default initialisation.
netG = Generator().to(device)
netG.apply(weights_init)
print(netG)

# Create the discriminator
netD = Discriminator().to(device)
netD.apply(weights_init)
print(netD)

# Initialize BCELoss function
criterion = nn.BCELoss()

# Create batch of latent vectors that I will use to visualize the progression of the generator
# (BATCH_SIZE vectors of shape (Z_DIM, 1, 1), drawn once from a standard normal)
viz_noise = torch.randn(BATCH_SIZE, Z_DIM, 1, 1, device=device)

# Setup Adam optimizers for both G and D
# Both use the shared learning rate `lr` and betas (0.5, 0.999).
optimizerD = optim.Adam(netD.parameters(), lr=lr, betas=(0.5, 0.999))
optimizerG = optim.Adam(netG.parameters(), lr=lr, betas=(0.5, 0.999))


Generator(
  (main): Sequential(
    (0): ConvTranspose2d(100, 512, kernel_size=(4, 4), stride=(1, 1), bias=False)
    (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): ConvTranspose2d(512, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (4): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace=True)
    (6): ConvTranspose2d(256, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (7): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (8): ReLU(inplace=True)
    (9): ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (10): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (11): ReLU(inplace=True)
    (12): ConvTranspose2d(64, 3, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (13): Tanh()
  )
)


In [None]:
#@title Optimization
import torch
import torch.optim as optim
import torch.nn as nn
from skimage.io import imread
from skimage.transform import resize
import os
from scipy.stats import kstest
import numpy as np

# Location of the saved training checkpoint on the mounted Drive.
checkpoint_dir = "/content/drive/MyDrive/HKU/Projects/GAN/checkpoints"
checkpoint_file = "checkpoint_epoch_999.pt"
checkpoint_path = os.path.join(checkpoint_dir, checkpoint_file)
# map_location=device remaps the stored tensors onto the device selected in
# the config cell, so a checkpoint saved from a GPU run also loads when the
# notebook has fallen back to CPU.
checkpoint = torch.load(checkpoint_path, map_location=device)


# Restore the generator weights, its optimizer state and the recorded generator losses.
netG.load_state_dict(checkpoint['netG_state_dict'])
optimizerG.load_state_dict(checkpoint['optimizerG_state_dict'])
G_losses = checkpoint['G_losses']
netG.to(device)
netG.eval()  # Set the generator network to evaluation mode

def optimize_z(netG, X, lr=0.001, loss_threshold=0.15, extra_steps_after_threshold=95, max_steps=None):
    """Find, for each image in ``X``, a latent vector whose generated image matches it.

    For every image a latent ``z`` is initialised to zeros and optimised with
    Adam to minimise the MSE between ``netG(z)`` and the image. Optimisation
    of an image stops once the loss has been at or below ``loss_threshold``
    for ``extra_steps_after_threshold`` consecutive steps.

    Every 100 steps the loss is printed together with a Kolmogorov-Smirnov
    test of the current ``z`` against the standard normal distribution. The
    KS test is diagnostic only and does not influence the optimisation.

    Args:
        netG: Generator module called as ``netG(z)``. Its parameters are not
            updated and receive no gradients here.
        X: Batch of target images; ``X[i]`` is compared with ``netG(z)[0]``.
        lr: Adam learning rate for the latent vector.
        loss_threshold: MSE value that counts as "close enough".
        extra_steps_after_threshold: Number of consecutive steps the loss must
            stay at or below ``loss_threshold`` before stopping.
        max_steps: Optional cap on optimisation steps per image. ``None``
            (default) means no cap: the loop runs until the threshold
            criterion is met, which may never happen for a hard image.

    Returns:
        Tensor of shape ``(X.size(0), Z_DIM, 1, 1)`` with the optimised
        latents, detached from the graph.
    """
    num_images = X.size(0)
    if num_images == 0:
        # torch.cat rejects an empty list; return an empty batch instead.
        return torch.empty(0, Z_DIM, 1, 1, device=device)

    loss_fn = nn.MSELoss()  # stateless, so one instance serves every image
    z_vectors = []

    for i in range(num_images):  # Loop over each image in X
        target = X[i].unsqueeze(0)  # add the batch dimension once per image
        z = torch.zeros(1, Z_DIM, 1, 1, device=device, requires_grad=True)  # Initialize z
        optimizer = optim.Adam([z], lr=lr)
        step = 0
        steps_after_threshold = 0

        while steps_after_threshold < extra_steps_after_threshold:
            if max_steps is not None and step >= max_steps:
                print(f"Image {i}: stopped after {step} steps without holding loss <= {loss_threshold}")
                break

            optimizer.zero_grad()
            loss = loss_fn(netG(z), target)  # Compare with the i-th image in X
            loss_item = loss.item()
            # Only z is being optimised; restricting backward to it avoids
            # accumulating unused gradients in the generator's parameters.
            loss.backward(inputs=[z])
            optimizer.step()                 # Update the latent vector

            if step % 100 == 0:
                # The KS test is only reported, so compute it (and pay for the
                # device-to-CPU copy) only on the steps that print it.
                z_flat = z.detach().cpu().numpy().flatten()
                ks_statistic, p_value = kstest(z_flat, 'norm')
                print(f"Image {i}, Step {step}, Loss: {loss_item}, KS Statistic: {ks_statistic}, P-value: {p_value}")

            step += 1

            # Count consecutive steps at/below the threshold; any step above it resets the count.
            if loss_item <= loss_threshold:
                steps_after_threshold += 1
            else:
                steps_after_threshold = 0

        z_vectors.append(z.detach())

    return torch.cat(z_vectors, dim=0)


def load_image(image_path, common_size=(64, 64), image_channels=3):
    """Read an image file and return it as a normalized ``(1, C, H, W)`` float tensor.

    The image is read with ``imread``, resized to ``common_size`` with
    anti-aliasing, mapped with ``(x - 0.5) / 0.5`` and moved to channels-first
    layout with a leading batch dimension.

    Args:
        image_path: Path of the image file to read.
        common_size: Target ``(height, width)`` passed to ``resize``.
        image_channels: Number of channels a grayscale image is expanded to.
            Images that already have colour channels are not reduced to match
            this value.

    Returns:
        ``torch.float32`` tensor of shape ``(1, C, common_size[0], common_size[1])``.
    """
    image = imread(image_path)
    image_resized = resize(image, common_size, anti_aliasing=True)

    if image_resized.ndim == 2:
        # Plain grayscale: add a channel axis so all cases share one path.
        image_resized = image_resized[..., np.newaxis]

    if image_resized.shape[-1] in (2, 4):
        # Gray+alpha or RGBA: drop the alpha channel, otherwise the tensor
        # ends up with one channel more than the colour data has.
        image_resized = image_resized[..., :-1]

    if image_resized.shape[-1] == 1:
        # Single channel: replicate it to the requested channel count.
        image_resized = np.repeat(image_resized, image_channels, axis=-1)

    image_normalized = (image_resized - 0.5) / 0.5 # Normalize to [-1, 1]
    image_tensor = torch.tensor(image_normalized.transpose((2, 0, 1)), dtype=torch.float32)

    return image_tensor.unsqueeze(0) # Add a batch dimension

# Load all images from the folder
folder_path = "/content/drive/MyDrive/HKU/Projects/GAN/test"
# os.listdir returns entries in arbitrary order; sort them so that row k of
# X / optimized_z corresponds to image_files[k] on every run.
# NOTE(review): the filter is case-sensitive and does not match '.jpeg' - confirm that is intended.
image_files = sorted(f for f in os.listdir(folder_path) if f.endswith(('.jpg', '.png')))

if image_files:
    X = []
    for image_file in image_files:
        image_path = os.path.join(folder_path, image_file)
        X.append(load_image(image_path, common_size=(64, 64), image_channels=3))

    # Stack the per-image (1, C, H, W) tensors into one batch on the target device.
    X = torch.cat(X, dim=0).to(device)
    optimized_z = optimize_z(netG, X)
else:
    # X and optimized_z are not assigned on this branch.
    print("No images found in the folder.")

Image 0, Step 0, Loss: 0.5773603916168213, KS Statistic: 0.4996010578600262, P-value: 1.328747138199462e-23
Image 0, Step 100, Loss: 0.3942347764968872, KS Statistic: 0.45909284246226184, P-value: 8.466900970210012e-20
Image 0, Step 200, Loss: 0.3719863295555115, KS Statistic: 0.432573108206318, P-value: 1.5931896129706993e-17
Image 0, Step 300, Loss: 0.3567718267440796, KS Statistic: 0.3969975850984505, P-value: 9.988325698726413e-15
Image 0, Step 400, Loss: 0.3419489860534668, KS Statistic: 0.3818965008209943, P-value: 1.2649613298633668e-13
Image 0, Step 500, Loss: 0.32572853565216064, KS Statistic: 0.35407996852742923, P-value: 1.0121753728595365e-11
Image 0, Step 600, Loss: 0.31259244680404663, KS Statistic: 0.3359714394568195, P-value: 1.43654744002546e-10
Image 0, Step 700, Loss: 0.2906751334667206, KS Statistic: 0.31367909456678555, P-value: 3.0462579775869196e-09
Image 0, Step 800, Loss: 0.2276442050933838, KS Statistic: 0.285031466777156, P-value: 1.10408669653609e-07
Image 0

In [None]:
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score


# Calculate metrics
# NOTE(review): y_true and y_pred are not assigned in any cell visible here;
# presumably they come from a cell that is not shown - confirm before running.
accuracy = accuracy_score(y_true, y_pred)
precision = precision_score(y_true, y_pred)
recall = recall_score(y_true, y_pred)
f1 = f1_score(y_true, y_pred)

# Print metrics
for metric_label, metric_value in (
    ("Accuracy: ", accuracy),
    ("Precision: ", precision),
    ("Recall: ", recall),
    ("F1 Score: ", f1),
):
    print(metric_label, metric_value)