In [1]:
# Import necessary libraries
import os
import torch
import torch.nn as nn
import torch.optim as optim
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.distributed import DistributedSampler
from torchvision import transforms
from PIL import Image
from tqdm import tqdm


In [2]:
# Define the Generator
class Generator(nn.Module):
    """Encoder-decoder generator conditioned on a simulated map.

    The simulated map and a noise map are stacked along the channel axis, so
    the first convolution takes ``noise_dim + 1`` channels (one channel for the
    simulated map). Five stride-2 convolutions shrink the spatial size by 32x.
    Five stride-2 transposed convolutions grow it back, and ``Tanh`` bounds the
    single-channel output to [-1, 1]. Height and width must be multiples of 32
    for the output to match the input size.

    Args:
        noise_dim: number of channels in the noise map passed to ``forward``.
    """

    def __init__(self, noise_dim):
        super().__init__()
        # Spatial sizes in the comments assume a 256x256 input.
        encoder_widths = [noise_dim + 1, 64, 128, 256, 512, 1024]  # 256 -> 8
        decoder_widths = [1024, 512, 256, 128, 64]                 # 8 -> 128

        stages = []
        for in_ch, out_ch in zip(encoder_widths, encoder_widths[1:]):
            stages += [
                nn.Conv2d(in_ch, out_ch, 4, 2, 1, bias=False),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(True),
            ]
        for in_ch, out_ch in zip(decoder_widths, decoder_widths[1:]):
            stages += [
                nn.ConvTranspose2d(in_ch, out_ch, 4, 2, 1, bias=False),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(True),
            ]
        # Final upsample to a single channel (1, 256, 256), squashed to [-1, 1].
        stages += [nn.ConvTranspose2d(64, 1, 4, 2, 1, bias=False), nn.Tanh()]

        # Kept as ``self.main`` so saved state_dict keys ("main.<i>.*") still load.
        self.main = nn.Sequential(*stages)

    def forward(self, simulated_map, noise_map):
        """Generate an image from a simulated map and a same-sized noise map."""
        stacked = torch.cat((simulated_map, noise_map), 1)
        return self.main(stacked)


In [3]:
# Define the Discriminator
class Discriminator(nn.Module):
    """Conditional discriminator over a (condition, candidate) image pair.

    Expects two stacked channels: the simulated map plus either a measured map
    or a generated one. Five stride-2 convolutions produce a spatial grid of
    sigmoid probabilities rather than a single scalar; for a 256x256 input the
    output has shape (N, 1, 8, 8).
    """

    def __init__(self):
        super().__init__()
        # The first layer has no BatchNorm, as in the original design.
        layers = [
            nn.Conv2d(2, 64, 4, 2, 1, bias=False),  # (64, 128, 128)
            nn.LeakyReLU(0.2, inplace=True),
        ]
        # (128, 64, 64) -> (256, 32, 32) -> (512, 16, 16)
        for in_ch, out_ch in ((64, 128), (128, 256), (256, 512)):
            layers.extend((
                nn.Conv2d(in_ch, out_ch, 4, 2, 1, bias=False),
                nn.BatchNorm2d(out_ch),
                nn.LeakyReLU(0.2, inplace=True),
            ))
        # Per-location probability map, (1, 8, 8).
        layers.extend((nn.Conv2d(512, 1, 4, 2, 1, bias=False), nn.Sigmoid()))

        # Kept as ``self.main`` so saved state_dict keys ("main.<i>.*") still load.
        self.main = nn.Sequential(*layers)

    def forward(self, input):
        """Score a 2-channel (condition, candidate) stack; values lie in [0, 1]."""
        return self.main(input)


In [4]:
# Setup distributed environment
def setup(rank, world_size):
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)


In [5]:
# Build models and wrap with DDP
def build_model(rank, noise_dim):
    """Create the generator and discriminator on GPU ``rank``, each wrapped in DDP.

    Args:
        rank: CUDA device index for this process (also its process-group rank).
        noise_dim: number of noise channels the generator expects.

    Returns:
        ``(generator, discriminator)`` as ``DistributedDataParallel`` wrappers;
        the underlying modules are available via ``.module``.
    """
    networks = (Generator(noise_dim), Discriminator())
    return tuple(DDP(net.to(rank), device_ids=[rank]) for net in networks)


In [6]:
# Define Custom Dataset
class CustomDataset(Dataset):
    """Paired image dataset: item ``i`` is (``dpm_dir`` file i, ``irt_dir`` file i).

    Pairs are formed purely by position after sorting each directory's file
    names. Both directories must therefore hold the same number of files, and
    corresponding files must sort into the same order. The training loop
    unpacks each pair as ``(simulated_map, measured_map)``.

    Args:
        dpm_dir: directory of input images (first element of each pair).
        irt_dir: directory of target images (second element of each pair).
        transform: optional callable applied to each PIL image.

    Raises:
        ValueError: if the two directories contain different numbers of files,
            which would make index-based pairing misaligned.
    """

    def __init__(self, dpm_dir, irt_dir, transform=None):
        self.dpm_dir = dpm_dir
        self.irt_dir = irt_dir
        self.transform = transform
        self.dpm_images = self._list_files(dpm_dir)
        self.irt_images = self._list_files(irt_dir)
        if len(self.dpm_images) != len(self.irt_images):
            raise ValueError(
                f"{dpm_dir!r} has {len(self.dpm_images)} files but {irt_dir!r} has "
                f"{len(self.irt_images)}; pairs are matched by sorted index, so the "
                f"counts must be equal"
            )

    @staticmethod
    def _list_files(directory):
        """Return the sorted names of regular, non-hidden files in ``directory``."""
        # Skip sub-directories and dot-entries (e.g. ``.ipynb_checkpoints``,
        # ``.DS_Store``) so they cannot shift the index-based pairing.
        return sorted(
            name for name in os.listdir(directory)
            if not name.startswith(".") and os.path.isfile(os.path.join(directory, name))
        )

    @staticmethod
    def _load(path):
        """Open an image, return a fully loaded in-memory copy, and close the file."""
        # Image.open is lazy and keeps the file handle open; copy() forces the
        # pixel data to load before the ``with`` block closes the file.
        with Image.open(path) as img:
            return img.copy()

    def __len__(self):
        return len(self.dpm_images)

    def __getitem__(self, idx):
        dpm_image = self._load(os.path.join(self.dpm_dir, self.dpm_images[idx]))
        irt_image = self._load(os.path.join(self.irt_dir, self.irt_images[idx]))

        if self.transform:
            dpm_image = self.transform(dpm_image)
            irt_image = self.transform(irt_image)

        return dpm_image, irt_image


In [7]:
# Get DataLoader
def _resolve_data_dir(explicit, global_name, env_name):
    """Return the first non-empty of: explicit argument, module global, env var."""
    # Environment variables set in the launching process are inherited by
    # workers started with mp.spawn, so they are a reliable way to configure them.
    for candidate in (explicit, globals().get(global_name), os.environ.get(env_name)):
        if candidate:
            return candidate
    raise ValueError(
        f"No data directory for {global_name!r}: pass {global_name}=..., define a "
        f"module-level {global_name!r}, or set the {env_name} environment variable"
    )


def get_data_loader(rank, world_size, batch_size, dpm_dir=None, irt_dir=None):
    """Build this rank's DataLoader over the paired DPM/IRT image dataset.

    Images are converted to single-channel grayscale tensors and normalised
    from [0, 1] to [-1, 1], the same range as the generator's Tanh output.
    ``DistributedSampler`` gives each rank a disjoint shard of the dataset.

    Args:
        rank: this process's rank.
        world_size: total number of processes sharing the dataset.
        batch_size: per-process batch size.
        dpm_dir: directory of input images. If omitted, falls back to a
            module-level ``dpm_dir`` and then to the ``DPM_DIR`` env var.
        irt_dir: directory of target images. If omitted, falls back to a
            module-level ``irt_dir`` and then to the ``IRT_DIR`` env var.

    Returns:
        A ``DataLoader`` whose ``sampler`` is a ``DistributedSampler``.

    Raises:
        ValueError: if a data directory cannot be resolved.
    """
    dpm_dir = _resolve_data_dir(dpm_dir, "dpm_dir", "DPM_DIR")
    irt_dir = _resolve_data_dir(irt_dir, "irt_dir", "IRT_DIR")

    transform = transforms.Compose([
        transforms.Grayscale(),
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,))
    ])
    dataset = CustomDataset(dpm_dir, irt_dir, transform=transform)
    sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
    return DataLoader(dataset, batch_size=batch_size, sampler=sampler)


In [8]:
# Define Training Loop
def train(rank, world_size, noise_dim, batch_size, num_epochs, lr):
    """Run the conditional-GAN training loop inside one DDP worker process.

    Used as the ``mp.spawn`` target: the process index arrives as ``rank`` and
    the remaining arguments come from ``args``.

    Args:
        rank: process rank; also the CUDA device index used for all tensors.
        world_size: total number of worker processes.
        noise_dim: number of noise channels fed to the generator.
        batch_size: per-process DataLoader batch size.
        num_epochs: number of passes over this rank's shard of the data.
        lr: Adam learning rate for both networks.

    Side effects:
        Rank 0 prints losses every 100 steps. After training, it writes
        ``generator.pth`` and ``discriminator.pth`` (unwrapped ``state_dict``s)
        to the current working directory. The process group is torn down even
        if training raises.
    """
    setup(rank, world_size)
    try:
        generator, discriminator = build_model(rank, noise_dim)
        criterion = nn.BCELoss()
        d_optimizer = optim.Adam(discriminator.parameters(), lr=lr, betas=(0.5, 0.999))
        g_optimizer = optim.Adam(generator.parameters(), lr=lr, betas=(0.5, 0.999))

        data_loader = get_data_loader(rank, world_size, batch_size)
        num_steps = len(data_loader)

        for epoch in range(num_epochs):
            # Give DistributedSampler a new shuffle seed each epoch.
            data_loader.sampler.set_epoch(epoch)
            # Only rank 0 draws a progress bar; otherwise every worker prints one.
            progress = tqdm(data_loader, disable=rank != 0)
            for step, (simulated_map, measured_map) in enumerate(progress):
                simulated_map = simulated_map.to(rank)
                measured_map = measured_map.to(rank)

                # Noise takes its batch and spatial size from the input instead
                # of a hard-coded 256x256.
                noise_map = torch.randn(
                    simulated_map.size(0), noise_dim, *simulated_map.shape[2:], device=rank
                )
                fake_images = generator(simulated_map, noise_map)

                ############################
                # Train discriminator
                ############################
                d_optimizer.zero_grad()

                real_outputs = discriminator(torch.cat((simulated_map, measured_map), dim=1))
                # The discriminator returns a spatial score map (1x8x8 for 256x256
                # inputs), so targets are shaped like its output; a (batch, 1)
                # target is rejected by BCELoss.
                d_loss_real = criterion(real_outputs, torch.ones_like(real_outputs))
                d_loss_real.backward()

                # detach(): the discriminator update must not backpropagate
                # into the generator.
                fake_outputs = discriminator(
                    torch.cat((simulated_map, fake_images.detach()), dim=1)
                )
                d_loss_fake = criterion(fake_outputs, torch.zeros_like(fake_outputs))
                d_loss_fake.backward()
                # NOTE(review): real and fake terms keep separate backward() calls,
                # each right after its own forward. Merging them into one backward
                # may trip autograd's in-place check when DDP re-broadcasts
                # BatchNorm buffers on every forward; confirm before combining.

                d_loss = d_loss_real + d_loss_fake
                d_optimizer.step()

                ############################
                # Train generator
                ############################
                g_optimizer.zero_grad()

                # Reuse this step's fake batch; the generator wants it scored as real.
                outputs = discriminator(torch.cat((simulated_map, fake_images), dim=1))
                g_loss = criterion(outputs, torch.ones_like(outputs))
                g_loss.backward()
                g_optimizer.step()

                ############################
                # Print losses
                ############################
                if step % 100 == 0 and rank == 0:
                    # tqdm.write prints without corrupting the active progress bar.
                    tqdm.write(f"Epoch [{epoch}/{num_epochs}], Step [{step}/{num_steps}], "
                               f"D_loss: {d_loss.item():.4f}, G_loss: {g_loss.item():.4f}")

        if rank == 0:
            torch.save(generator.module.state_dict(), 'generator.pth')
            torch.save(discriminator.module.state_dict(), 'discriminator.pth')
    finally:
        dist.destroy_process_group()


In [9]:
# Launch Training
def main(world_size=4, noise_dim=1, batch_size=64, num_epochs=10, lr=0.0002):
    """Launch one DDP training process per GPU with the given hyperparameters.

    Args:
        world_size: number of worker processes; each uses CUDA device ``rank``.
        noise_dim: number of noise channels fed to the generator.
        batch_size: per-process batch size.
        num_epochs: number of training epochs.
        lr: Adam learning rate for both networks.

    Raises:
        ValueError: if ``world_size`` is < 1 or exceeds the visible CUDA devices.
            Without this check, each worker would fail later inside
            ``torch.cuda.set_device``.

    NOTE(review): ``mp.spawn`` starts fresh interpreters that must re-import
    ``train``. Functions defined in notebook cells are typically not importable
    that way, so running this code as a ``.py`` script is the reliable path.
    Confirm in your environment.
    """
    available = torch.cuda.device_count()
    if not 1 <= world_size <= available:
        raise ValueError(
            f"world_size={world_size} requires that many CUDA devices, "
            f"but {available} are visible"
        )

    mp.spawn(train,
             args=(world_size, noise_dim, batch_size, num_epochs, lr),
             nprocs=world_size,
             join=True)

if __name__ == "__main__":
    main()
