In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torchvision.utils import save_image

In [3]:
class ResnetBlock(nn.Module):
    def __init__(self, channels):
        super(ResnetBlock, self).__init__()
        self.conv_block = nn.Sequential(
            nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1),
            nn.InstanceNorm2d(channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1),
            nn.InstanceNorm2d(channels),
        )

    def forward(self, x):
        return x + self.conv_block(x)

class Generator(nn.Module):
    def __init__(self, input_channels=1, output_channels=1, num_res_blocks=9):
        super(Generator, self).__init__()
        model = [
            nn.Conv2d(input_channels, 64, kernel_size=7, stride=1, padding=3),
            nn.InstanceNorm2d(64),
            nn.ReLU(inplace=True)
        ]

        in_channels = 64
        out_channels = 128
        for _ in range(2):
            model += [
                nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=1),
                nn.InstanceNorm2d(out_channels),
                nn.ReLU(inplace=True)
            ]
            in_channels = out_channels
            out_channels *= 2

        for _ in range(num_res_blocks):
            model += [ResnetBlock(in_channels)]

        out_channels = in_channels // 2
        for _ in range(2):
            model += [
                nn.ConvTranspose2d(in_channels, out_channels, kernel_size=3, stride=2, padding=1, output_padding=1),
                nn.InstanceNorm2d(out_channels),
                nn.ReLU(inplace=True)
            ]
            in_channels = out_channels
            out_channels //= 2

        model += [
            nn.Conv2d(in_channels, output_channels, kernel_size=7, stride=1, padding=3),
            nn.Tanh()
        ]

        self.model = nn.Sequential(*model)

    def forward(self, x):
        return self.model(x)

In [4]:
class Discriminator(nn.Module):
    def __init__(self, input_channels=1):
        super(Discriminator, self).__init__()
        model = [
            nn.Conv2d(input_channels, 64, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True)
        ]

        in_channels = 64
        out_channels = 128
        for _ in range(3):
            model += [
                nn.Conv2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1),
                nn.InstanceNorm2d(out_channels),
                nn.LeakyReLU(0.2, inplace=True)
            ]
            in_channels = out_channels
            out_channels *= 2

        model += [
            nn.Conv2d(in_channels, 1, kernel_size=4, stride=1, padding=1),
            nn.Sigmoid()
        ]

        self.model = nn.Sequential(*model)

    def forward(self, x):
        return self.model(x)

In [5]:
def adversarial_loss(y_pred, y_true):
    return nn.functional.mse_loss(y_pred, y_true)

def cycle_consistency_loss(original, reconstructed):
    return nn.functional.l1_loss(original, reconstructed)

def identity_loss(real, generated):
    return nn.functional.l1_loss(real, generated)

In [6]:
!pip install kagglehub
import kagglehub

path = kagglehub.dataset_download("paultimothymooney/chest-xray-pneumonia")

print("Path to dataset files:", path)

Downloading from https://www.kaggle.com/api/v1/datasets/download/paultimothymooney/chest-xray-pneumonia?dataset_version_number=2...


100%|██████████| 2.29G/2.29G [00:10<00:00, 224MB/s]

Extracting files...





Path to dataset files: /root/.cache/kagglehub/datasets/paultimothymooney/chest-xray-pneumonia/versions/2


In [7]:
import os

dataset_path = "/root/.cache/kagglehub/datasets/paultimothymooney/chest-xray-pneumonia/versions/2/chest_xray"

print("Contents of chest_xray:")
print(os.listdir(dataset_path))

Contents of chest_xray:
['__MACOSX', 'val', 'train', 'chest_xray', 'test']


In [8]:
for root, dirs, files in os.walk(dataset_path):
    print(f"📂 {root}")
    for d in dirs:
        print(f"  ├── {d}")


📂 /root/.cache/kagglehub/datasets/paultimothymooney/chest-xray-pneumonia/versions/2/chest_xray
  ├── __MACOSX
  ├── val
  ├── train
  ├── chest_xray
  ├── test
📂 /root/.cache/kagglehub/datasets/paultimothymooney/chest-xray-pneumonia/versions/2/chest_xray/__MACOSX
  ├── chest_xray
📂 /root/.cache/kagglehub/datasets/paultimothymooney/chest-xray-pneumonia/versions/2/chest_xray/__MACOSX/chest_xray
  ├── val
  ├── train
  ├── test
📂 /root/.cache/kagglehub/datasets/paultimothymooney/chest-xray-pneumonia/versions/2/chest_xray/__MACOSX/chest_xray/val
  ├── NORMAL
  ├── PNEUMONIA
📂 /root/.cache/kagglehub/datasets/paultimothymooney/chest-xray-pneumonia/versions/2/chest_xray/__MACOSX/chest_xray/val/NORMAL
📂 /root/.cache/kagglehub/datasets/paultimothymooney/chest-xray-pneumonia/versions/2/chest_xray/__MACOSX/chest_xray/val/PNEUMONIA
📂 /root/.cache/kagglehub/datasets/paultimothymooney/chest-xray-pneumonia/versions/2/chest_xray/__MACOSX/chest_xray/train
  ├── NORMAL
  ├── PNEUMONIA
📂 /root/.cache/kag

In [9]:
chest_xray_path = "/root/.cache/kagglehub/datasets/paultimothymooney/chest-xray-pneumonia/versions/2/chest_xray"
train_path = os.path.join(chest_xray_path, "train")

In [10]:
import os

dataset_path = "/root/.cache/kagglehub/datasets/paultimothymooney/chest-xray-pneumonia/versions/2/chest_xray/chest_xray"

train_path = os.path.join(dataset_path, "train")
test_path = os.path.join(dataset_path, "test")
val_path = os.path.join(dataset_path, "val")

train_normal_path = os.path.join(train_path, "NORMAL")
train_pneumonia_path = os.path.join(train_path, "PNEUMONIA")

test_normal_path = os.path.join(test_path, "NORMAL")
test_pneumonia_path = os.path.join(test_path, "PNEUMONIA")

val_normal_path = os.path.join(val_path, "NORMAL")
val_pneumonia_path = os.path.join(val_path, "PNEUMONIA")

print(f"Train Normal: {len(os.listdir(train_normal_path))} images")
print(f"Train Pneumonia: {len(os.listdir(train_pneumonia_path))} images")
print(f"Test Normal: {len(os.listdir(test_normal_path))} images")
print(f"Test Pneumonia: {len(os.listdir(test_pneumonia_path))} images")
print(f"Val Normal: {len(os.listdir(val_normal_path))} images")
print(f"Val Pneumonia: {len(os.listdir(val_pneumonia_path))} images")

Train Normal: 1342 images
Train Pneumonia: 3876 images
Test Normal: 234 images
Test Pneumonia: 390 images
Val Normal: 9 images
Val Pneumonia: 9 images


In [11]:
import os
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image

class ChestXrayDataset(Dataset):
    def __init__(self, root_A, root_B, transform=None):
        self.files_A = sorted([os.path.join(root_A, f) for f in os.listdir(root_A) if f.endswith('.jpeg') or f.endswith('.png')])
        self.files_B = sorted([os.path.join(root_B, f) for f in os.listdir(root_B) if f.endswith('.jpeg') or f.endswith('.png')])
        self.transform = transform

    def __len__(self):
        return min(len(self.files_A), len(self.files_B))

    def __getitem__(self, index):
        img_A = Image.open(self.files_A[index]).convert("L")
        img_B = Image.open(self.files_B[index]).convert("L")

        if self.transform:
            img_A = self.transform(img_A)
            img_B = self.transform(img_B)

        return img_A, img_B

transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

dataset = ChestXrayDataset(train_normal_path, train_pneumonia_path, transform=transform)
dataloader = DataLoader(dataset, batch_size=16, shuffle=True, num_workers=4, pin_memory=True)

for normal, pneumonia in dataloader:
    print(f"Batch size: {normal.shape}, {pneumonia.shape}")
    break

Batch size: torch.Size([16, 1, 256, 256]), torch.Size([16, 1, 256, 256])


In [13]:
import os
import torch

os.makedirs("checkpoints", exist_ok=True)
torch.save(G_AB.state_dict(), f"checkpoints/G_AB_{epoch}.pth")
torch.save(G_BA.state_dict(), f"checkpoints/G_BA_{epoch}.pth")


In [14]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

G_AB = Generator().to(device)
G_BA = Generator().to(device)
D_A = Discriminator().to(device)
D_B = Discriminator().to(device)

optimizer_G = torch.optim.Adam(list(G_AB.parameters()) + list(G_BA.parameters()), lr=0.0001, betas=(0.5, 0.999))
optimizer_D_A = torch.optim.Adam(D_A.parameters(), lr=0.0001, betas=(0.5, 0.999))
optimizer_D_B = torch.optim.Adam(D_B.parameters(), lr=0.0001, betas=(0.5, 0.999))

num_epochs = 200
for epoch in range(num_epochs):
    for i, (real_A, real_B) in enumerate(dataloader):
        real_A, real_B = real_A.to(device), real_B.to(device)

        fake_B = G_AB(real_A)
        fake_A = G_BA(real_B)

        loss_GAN_AB = adversarial_loss(D_B(fake_B), torch.ones_like(D_B(fake_B)))
        loss_GAN_BA = adversarial_loss(D_A(fake_A), torch.ones_like(D_A(fake_A)))

        loss_cycle_A = cycle_consistency_loss(real_A, G_BA(fake_B))
        loss_cycle_B = cycle_consistency_loss(real_B, G_AB(fake_A))

        loss_identity_A = identity_loss(real_A, G_BA(real_A))
        loss_identity_B = identity_loss(real_B, G_AB(real_B))

        loss_G = loss_GAN_AB + loss_GAN_BA + 10 * (loss_cycle_A + loss_cycle_B) + 5 * (loss_identity_A + loss_identity_B)

        optimizer_G.zero_grad()
        loss_G.backward()
        optimizer_G.step()

        optimizer_D_A.zero_grad()
        loss_D_A = adversarial_loss(D_A(real_A), torch.ones_like(D_A(real_A))) + adversarial_loss(D_A(fake_A.detach()), torch.zeros_like(D_A(fake_A.detach())))
        loss_D_A.backward()
        optimizer_D_A.step()

        optimizer_D_B.zero_grad()
        loss_D_B = adversarial_loss(D_B(real_B), torch.ones_like(D_B(real_B))) + adversarial_loss(D_B(fake_B.detach()), torch.zeros_like(D_B(fake_B.detach())))
        loss_D_B.backward()
        optimizer_D_B.step()

        if i % 10 == 0:
            print(f"Epoch [{epoch}/{num_epochs}], Step [{i}/{len(dataloader)}], Loss_G: {loss_G.item()}, Loss_D_A: {loss_D_A.item()}, Loss_D_B: {loss_D_B.item()}")

    if epoch % 10 == 0:
        torch.save(G_AB.state_dict(), f"checkpoints/G_AB_{epoch}.pth")
        torch.save(G_BA.state_dict(), f"checkpoints/G_BA_{epoch}.pth")

Epoch [0/200], Step [0/84], Loss_G: 14.444938659667969, Loss_D_A: 0.5296505689620972, Loss_D_B: 0.5321570634841919
Epoch [0/200], Step [10/84], Loss_G: 5.012276649475098, Loss_D_A: 0.4562816023826599, Loss_D_B: 0.4809805154800415
Epoch [0/200], Step [20/84], Loss_G: 4.800670623779297, Loss_D_A: 0.472589910030365, Loss_D_B: 0.46967169642448425
Epoch [0/200], Step [30/84], Loss_G: 4.093136787414551, Loss_D_A: 0.46267759799957275, Loss_D_B: 0.45862317085266113
Epoch [0/200], Step [40/84], Loss_G: 4.799318313598633, Loss_D_A: 0.4323997497558594, Loss_D_B: 0.4219668209552765
Epoch [0/200], Step [50/84], Loss_G: 4.61391544342041, Loss_D_A: 0.46005725860595703, Loss_D_B: 0.43283239006996155
Epoch [0/200], Step [60/84], Loss_G: 3.641369104385376, Loss_D_A: 0.44722265005111694, Loss_D_B: 0.42901700735092163
Epoch [0/200], Step [70/84], Loss_G: 3.674103260040283, Loss_D_A: 0.47371405363082886, Loss_D_B: 0.435285747051239
Epoch [0/200], Step [80/84], Loss_G: 3.925704002380371, Loss_D_A: 0.4578661

In [15]:
import torch
from PIL import Image
import torchvision.transforms as transforms

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
G_AB = Generator().to(device)
G_BA = Generator().to(device)


G_AB.load_state_dict(torch.load("checkpoints/G_AB_90.pth"))
G_BA.load_state_dict(torch.load("checkpoints/G_BA_90.pth"))

G_AB.eval()
G_BA.eval()

Generator(
  (model): Sequential(
    (0): Conv2d(1, 64, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3))
    (1): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    (2): ReLU(inplace=True)
    (3): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (4): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    (5): ReLU(inplace=True)
    (6): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (7): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    (8): ReLU(inplace=True)
    (9): ResnetBlock(
      (conv_block): Sequential(
        (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
        (2): ReLU(inplace=True)
        (3): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (4): InstanceNor

In [16]:
def preprocess_image(image_path):
    transform = transforms.Compose([
        transforms.Grayscale(num_output_channels=1),
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5], std=[0.5])
    ])

    image = Image.open(image_path).convert("L")
    return transform(image).unsqueeze(0).to(device)

In [17]:
def generate_image(input_image_path, generator, output_image_path):

    input_image = preprocess_image(input_image_path)


    with torch.no_grad():
        fake_image = generator(input_image)


    fake_image = fake_image.squeeze(0).cpu().detach()
    fake_image = transforms.ToPILImage()(fake_image)


    fake_image.save(output_image_path)
    print(f"Generated image saved as: {output_image_path}")

In [21]:
import os

dataset_path = "/root/.cache/kagglehub/datasets/paultimothymooney/chest-xray-pneumonia/versions/2/chest_xray/chest_xray/"
print("Directory exists:", os.path.exists(dataset_path))
print("Contents:", os.listdir(dataset_path) if os.path.exists(dataset_path) else "Not found")


Directory exists: True
Contents: ['.DS_Store', 'val', 'train', 'test']


In [24]:
!kaggle datasets download -d paultimothymooney/chest-xray-pneumonia --unzip -p /root/.cache/kagglehub/datasets/


Dataset URL: https://www.kaggle.com/datasets/paultimothymooney/chest-xray-pneumonia
License(s): other
Downloading chest-xray-pneumonia.zip to /root/.cache/kagglehub/datasets
 99% 2.28G/2.29G [00:10<00:00, 266MB/s]
100% 2.29G/2.29G [00:10<00:00, 244MB/s]


In [26]:
!find /root/.cache/kagglehub/datasets/ -type d -name "NORMAL"


/root/.cache/kagglehub/datasets/paultimothymooney/chest-xray-pneumonia/versions/2/chest_xray/__MACOSX/chest_xray/val/NORMAL
/root/.cache/kagglehub/datasets/paultimothymooney/chest-xray-pneumonia/versions/2/chest_xray/__MACOSX/chest_xray/train/NORMAL
/root/.cache/kagglehub/datasets/paultimothymooney/chest-xray-pneumonia/versions/2/chest_xray/__MACOSX/chest_xray/test/NORMAL
/root/.cache/kagglehub/datasets/paultimothymooney/chest-xray-pneumonia/versions/2/chest_xray/val/NORMAL
/root/.cache/kagglehub/datasets/paultimothymooney/chest-xray-pneumonia/versions/2/chest_xray/train/NORMAL
/root/.cache/kagglehub/datasets/paultimothymooney/chest-xray-pneumonia/versions/2/chest_xray/chest_xray/val/NORMAL
/root/.cache/kagglehub/datasets/paultimothymooney/chest-xray-pneumonia/versions/2/chest_xray/chest_xray/train/NORMAL
/root/.cache/kagglehub/datasets/paultimothymooney/chest-xray-pneumonia/versions/2/chest_xray/chest_xray/test/NORMAL
/root/.cache/kagglehub/datasets/paultimothymooney/chest-xray-pneumo

In [28]:
import os


normal_dir = "/root/.cache/kagglehub/datasets/paultimothymooney/chest-xray-pneumonia/versions/2/chest_xray/__MACOSX/chest_xray/val/NORMAL"
pneumonia_dir = "/root/.cache/kagglehub/datasets/paultimothymooney/chest-xray-pneumonia/versions/2/chest_xray/__MACOSX/chest_xray/val/PNEUMONIA"


def get_valid_image_path(directory):
    files = [f for f in os.listdir(directory) if f.endswith(('.jpeg', '.jpg', '.png'))]
    if not files:
        raise FileNotFoundError(f"No valid images found in {directory}")
    return os.path.join(directory, files[0])


normal_image_path = get_valid_image_path(normal_dir)
pneumonia_image_path = get_valid_image_path(pneumonia_dir)

print("Normal X-ray:", normal_image_path)
print("Pneumonia X-ray:", pneumonia_image_path)

Normal X-ray: /root/.cache/kagglehub/datasets/paultimothymooney/chest-xray-pneumonia/versions/2/chest_xray/__MACOSX/chest_xray/val/NORMAL/._NORMAL2-IM-1436-0001.jpeg
Pneumonia X-ray: /root/.cache/kagglehub/datasets/paultimothymooney/chest-xray-pneumonia/versions/2/chest_xray/__MACOSX/chest_xray/val/PNEUMONIA/._person1952_bacteria_4883.jpeg


In [30]:
def get_valid_image_path(directory):
    files = [f for f in os.listdir(directory) if f.endswith(('.jpeg', '.jpg', '.png')) and not f.startswith("._")]
    if not files:
        raise FileNotFoundError(f"No valid images found in {directory}")
    return os.path.join(directory, files[0])


In [33]:
!find /root/.cache/kagglehub/datasets/ -name "._*" -delete


In [42]:
import os

def get_valid_image_path(directory):
    files = [f for f in os.listdir(directory) if f.endswith(('.jpeg', '.jpg', '.png')) and not f.startswith("._")]
    if not files:
        raise FileNotFoundError(f"No valid images found in {directory}")
    return os.path.join(directory, files[0])


In [43]:
import os

normal_dir = "/root/.cache/kagglehub/datasets/paultimothymooney/chest-xray-pneumonia/versions/2/chest_xray/chest_xray/val/NORMAL"

if os.path.exists(normal_dir):
    print("NORMAL directory exists.")
    print("Files in NORMAL directory:", os.listdir(normal_dir))
else:
    print("Directory does not exist!")


NORMAL directory exists.
Files in NORMAL directory: ['.DS_Store', 'NORMAL2-IM-1437-0001.jpeg', 'NORMAL2-IM-1436-0001.jpeg', 'NORMAL2-IM-1427-0001.jpeg', 'NORMAL2-IM-1430-0001.jpeg', 'NORMAL2-IM-1431-0001.jpeg', 'NORMAL2-IM-1438-0001.jpeg', 'NORMAL2-IM-1440-0001.jpeg', 'NORMAL2-IM-1442-0001.jpeg']


In [46]:
import os

def get_valid_image_path(directory):
    if not os.path.exists(directory):
        raise FileNotFoundError(f"Directory does not exist: {directory}")


    files = [f for f in os.listdir(directory) if f.endswith(('.jpeg', '.jpg', '.png')) and not f.startswith(".")]

    if not files:
        raise FileNotFoundError(f"No valid images found in {directory}")

    return os.path.join(directory, files[0])


normal_dir = "/root/.cache/kagglehub/datasets/paultimothymooney/chest-xray-pneumonia/versions/2/chest_xray/chest_xray/val/NORMAL"


normal_image_path = get_valid_image_path(normal_dir)

print("Valid image path:", normal_image_path)


Valid image path: /root/.cache/kagglehub/datasets/paultimothymooney/chest-xray-pneumonia/versions/2/chest_xray/chest_xray/val/NORMAL/NORMAL2-IM-1437-0001.jpeg


In [47]:
generate_image(normal_image_path, G_AB, "generated_pneumonia.png")

Generated image saved as: generated_pneumonia.png


In [49]:
!find /root/.cache/kagglehub/datasets/ -name "._*" -delete


In [50]:
import os

def get_valid_image_path(directory):
    if not os.path.exists(directory):
        raise FileNotFoundError(f"Directory does not exist: {directory}")


    files = [f for f in os.listdir(directory) if f.lower().endswith(('.jpeg', '.jpg', '.png')) and not f.startswith(".")]

    if not files:
        raise FileNotFoundError(f"No valid images found in {directory}")

    return os.path.join(directory, files[0])


pneumonia_dir = "/root/.cache/kagglehub/datasets/paultimothymooney/chest-xray-pneumonia/versions/2/chest_xray/chest_xray/val/PNEUMONIA"


pneumonia_image_path = get_valid_image_path(pneumonia_dir)

print("Valid Pneumonia image path:", pneumonia_image_path)


Valid Pneumonia image path: /root/.cache/kagglehub/datasets/paultimothymooney/chest-xray-pneumonia/versions/2/chest_xray/chest_xray/val/PNEUMONIA/person1947_bacteria_4876.jpeg


In [51]:
import os
print("File exists:", os.path.exists(pneumonia_image_path))


File exists: True


In [52]:
generate_image(pneumonia_image_path, G_BA, "generated_normal.png")

Generated image saved as: generated_normal.png


In [55]:
import os

file_path = "/root/.cache/kagglehub/datasets/paultimothymooney/chest-xray-pneumonia/versions/2/chest_xray/chest_xray/val/NORMAL/NORMAL2-IM-1437-0001.jpeg"

if os.path.exists(file_path):
    print("✅ File exists!")
else:
    print("❌ File NOT found! Check the path again.")

✅ File exists!


In [56]:
import torch
import gc

gc.collect()
torch.cuda.empty_cache()