In [1]:
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchsummary import summary
import torch.optim as optim
import torch.nn as nn
from tqdm import tqdm
from PIL import Image
import os
import random
import matplotlib.pyplot as plt
from gen_disc_networks import Generator, Discriminator
from customDataSet import CustomDataset

In [2]:
# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Hyperparameters

pixel_size = 256
batch_size = 4
noise_dim = 2
lr = 0.0002
num_epochs = 50
ifTrain = True
ifSaveModel = True
if ifTrain == True:
    print("true")
else:
    print("false")

G = Generator(noise_dim)
D = Discriminator()
#G.get_submodule

generator = G.to(device)
discriminator = D.to(device)

true


Getting the model summary

In [3]:
print("Generator summary")
summary(G, [(1, pixel_size, pixel_size), (noise_dim, pixel_size, pixel_size)])

#print("Discriminator summary")
#summary(D, (2, 256, 256))
#for parameter in G.parameters():
#    print(parameter)
# Initialize models


Generator summary
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 64, 128, 128]           3,072
       BatchNorm2d-2         [-1, 64, 128, 128]             128
              ReLU-3         [-1, 64, 128, 128]               0
            Conv2d-4          [-1, 128, 64, 64]         131,072
       BatchNorm2d-5          [-1, 128, 64, 64]             256
              ReLU-6          [-1, 128, 64, 64]               0
            Conv2d-7          [-1, 256, 32, 32]         524,288
       BatchNorm2d-8          [-1, 256, 32, 32]             512
              ReLU-9          [-1, 256, 32, 32]               0
           Conv2d-10          [-1, 512, 16, 16]       2,097,152
      BatchNorm2d-11          [-1, 512, 16, 16]           1,024
             ReLU-12          [-1, 512, 16, 16]               0
           Conv2d-13           [-1, 1024, 8, 8]       8,388,608
      BatchNorm2d-14 

In [4]:
# Create dataset instance
transform = transforms.Compose([
    transforms.Resize((pixel_size, pixel_size)),  # Resize to pixel_size x pixel_size
    transforms.ToTensor()  # Convert to tensor
])
dataset_dpm_path = 'dataset/DPM/'
dataset_irt_path = 'dataset/IRT2/'

dataset = CustomDataset(dataset_dpm_path, dataset_irt_path, transform=transform)

# Create data loader
data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Choose a random index
random_index = random.randint(0, len(dataset) - 1)
# Get the image at the random index

image_dpm, image_irt = dataset[random_index]

# Convert tensor to numpy array and remove batch dimension
image_dpm_np = image_dpm.squeeze().numpy()
image_irt_np = image_irt.squeeze().numpy()



In [None]:
if ifTrain == True:
    
    # Define loss function and optimizers
    criterion = nn.BCELoss()
    d_optimizer = optim.Adam(discriminator.parameters(), lr=lr, betas=(0.5, 0.999))
    g_optimizer = optim.Adam(generator.parameters(), lr=lr, betas=(0.5, 0.999))
    
    # Training loop
    for epoch in range(num_epochs):
        for i, batch in enumerate(tqdm(data_loader)):
            # Unpack the batch
            simulated_map, measured_map = batch
            print(simulated_map.size())
    
            # Move tensors to device
            simulated_map = simulated_map.to(device)
            measured_map = measured_map.to(device)
    
    
            ############################
            # Train discriminator
            ############################
            d_optimizer.zero_grad()
    
    
    
            # Train with real data
            real_labels = torch.ones(batch_size, 1, device=device)
            # Concatenate the simulated map and measured map along the channel dimension
            real_inputs = torch.cat((simulated_map, measured_map), dim=1)
            real_outputs = discriminator(real_inputs)
    #        print(f"real label: {real_labels.size()}")
    #        print(f"real output: {real_outputs.size()}")
            d_loss_real = criterion(real_outputs.squeeze(), real_labels.squeeze())
            d_loss_real.backward()
    
            # Train with fake data
            noise_map = torch.randn(batch_size, noise_dim, pixel_size, pixel_size, device=device)
            fake_images = generator(simulated_map, noise_map)
            fake_labels = torch.zeros(batch_size, 1, device=device)
    #        print(f"fake image size @generator's output: {fake_images.size()}")
    #        print(f"simulated_map size: {simulated_map.size()}")
            fake_inputs = torch.cat((simulated_map, fake_images), dim=1)
            fake_outputs = discriminator(fake_inputs)
    #        print(f"fake label: {fake_labels.size()}")
    #        print(f"fake output: {fake_outputs.size()}")
            d_loss_fake = criterion(fake_outputs.squeeze(), fake_labels.squeeze())
            d_loss_fake.backward()
    
            d_loss = d_loss_real + d_loss_fake
            d_optimizer.step()
    
            ############################
            # Train generator
            ############################
            g_optimizer.zero_grad()
    
            # Generate fake images
            noise_map = torch.randn(batch_size, noise_dim, pixel_size, pixel_size, device=device)
            fake_images = generator(simulated_map, noise_map)
    
            # Train generator with discriminator feedback
            fake_inputs = torch.cat((simulated_map, fake_images), dim=1)
            outputs = discriminator(fake_inputs)
            g_loss = criterion(outputs.squeeze(), real_labels.squeeze())
            g_loss.backward()
            g_optimizer.step()
    
    
    
            ############################
            # Print losses
            ############################
            if i % 10 == 0:
                print(f"Epoch [{epoch}/{num_epochs}], Step [{i}/{len(data_loader)}], "
                      f"D_loss: {d_loss.item():.4f}, G_loss: {g_loss.item():.4f}")
                if ifSaveModel == True:
                    torch.save(generator.state_dict(), 'generator_trained.pth')
                    torch.save(generator, 'generator_entire_model.pth')
                    print("Model saved at epoch {i}")

  0%|                                   | 0/20 [00:00<?, ?it/s]

torch.Size([4, 1, 256, 256])
Epoch [0/50], Step [0/20], D_loss: 1.5260, G_loss: 2.1778


  5%|█▎                         | 1/20 [00:03<01:02,  3.31s/it]

Model saved at epoch {i}
torch.Size([4, 1, 256, 256])


 10%|██▋                        | 2/20 [00:06<00:59,  3.33s/it]

torch.Size([4, 1, 256, 256])


 15%|████                       | 3/20 [00:10<00:58,  3.43s/it]

torch.Size([4, 1, 256, 256])


 20%|█████▍                     | 4/20 [00:13<00:55,  3.47s/it]

torch.Size([4, 1, 256, 256])


 25%|██████▊                    | 5/20 [00:17<00:51,  3.42s/it]

torch.Size([4, 1, 256, 256])


 30%|████████                   | 6/20 [00:20<00:47,  3.40s/it]

torch.Size([4, 1, 256, 256])


 35%|█████████▍                 | 7/20 [00:23<00:43,  3.37s/it]

torch.Size([4, 1, 256, 256])


 40%|██████████▊                | 8/20 [00:27<00:40,  3.36s/it]

torch.Size([4, 1, 256, 256])


 45%|████████████▏              | 9/20 [00:30<00:36,  3.33s/it]

torch.Size([4, 1, 256, 256])


 50%|█████████████             | 10/20 [00:33<00:33,  3.31s/it]

torch.Size([4, 1, 256, 256])
Epoch [0/50], Step [10/20], D_loss: 0.0549, G_loss: 7.0921


 55%|██████████████▎           | 11/20 [00:37<00:30,  3.42s/it]

Model saved at epoch {i}
torch.Size([4, 1, 256, 256])


 60%|███████████████▌          | 12/20 [00:40<00:26,  3.30s/it]

torch.Size([4, 1, 256, 256])


 65%|████████████████▉         | 13/20 [00:43<00:22,  3.27s/it]

torch.Size([4, 1, 256, 256])


 70%|██████████████████▏       | 14/20 [00:47<00:19,  3.33s/it]

torch.Size([4, 1, 256, 256])


 75%|███████████████████▌      | 15/20 [00:50<00:16,  3.37s/it]

torch.Size([4, 1, 256, 256])


 80%|████████████████████▊     | 16/20 [00:53<00:13,  3.32s/it]

torch.Size([4, 1, 256, 256])


 85%|██████████████████████    | 17/20 [00:56<00:09,  3.31s/it]

torch.Size([4, 1, 256, 256])


 90%|███████████████████████▍  | 18/20 [01:00<00:06,  3.33s/it]

torch.Size([4, 1, 256, 256])


 95%|████████████████████████▋ | 19/20 [01:03<00:03,  3.29s/it]

torch.Size([4, 1, 256, 256])


100%|██████████████████████████| 20/20 [01:06<00:00,  3.34s/it]
  0%|                                   | 0/20 [00:00<?, ?it/s]

torch.Size([4, 1, 256, 256])
Epoch [1/50], Step [0/20], D_loss: 0.0080, G_loss: 7.3448


  5%|█▎                         | 1/20 [00:03<01:11,  3.74s/it]

Model saved at epoch {i}
torch.Size([4, 1, 256, 256])


 10%|██▋                        | 2/20 [00:07<01:02,  3.49s/it]

torch.Size([4, 1, 256, 256])


 15%|████                       | 3/20 [00:10<00:57,  3.36s/it]

torch.Size([4, 1, 256, 256])


 20%|█████▍                     | 4/20 [00:13<00:53,  3.32s/it]

torch.Size([4, 1, 256, 256])


 25%|██████▊                    | 5/20 [00:16<00:49,  3.30s/it]

torch.Size([4, 1, 256, 256])


 30%|████████                   | 6/20 [00:20<00:45,  3.28s/it]

torch.Size([4, 1, 256, 256])


 35%|█████████▍                 | 7/20 [00:23<00:42,  3.27s/it]

torch.Size([4, 1, 256, 256])


 40%|██████████▊                | 8/20 [00:26<00:39,  3.26s/it]

torch.Size([4, 1, 256, 256])


 45%|████████████▏              | 9/20 [00:29<00:35,  3.25s/it]

torch.Size([4, 1, 256, 256])


 50%|█████████████             | 10/20 [00:33<00:32,  3.29s/it]

torch.Size([4, 1, 256, 256])
Epoch [1/50], Step [10/20], D_loss: 0.0172, G_loss: 7.8593


 55%|██████████████▎           | 11/20 [00:36<00:30,  3.41s/it]

Model saved at epoch {i}
torch.Size([4, 1, 256, 256])


 60%|███████████████▌          | 12/20 [00:39<00:25,  3.25s/it]

torch.Size([4, 1, 256, 256])


 65%|████████████████▉         | 13/20 [00:42<00:22,  3.23s/it]

torch.Size([4, 1, 256, 256])


 70%|██████████████████▏       | 14/20 [00:46<00:19,  3.22s/it]

torch.Size([4, 1, 256, 256])


 75%|███████████████████▌      | 15/20 [00:49<00:16,  3.25s/it]

torch.Size([4, 1, 256, 256])


 80%|████████████████████▊     | 16/20 [00:52<00:13,  3.25s/it]

torch.Size([4, 1, 256, 256])


 85%|██████████████████████    | 17/20 [00:55<00:09,  3.25s/it]

torch.Size([4, 1, 256, 256])


 90%|███████████████████████▍  | 18/20 [00:59<00:06,  3.29s/it]

torch.Size([4, 1, 256, 256])


 95%|████████████████████████▋ | 19/20 [01:03<00:03,  3.46s/it]

torch.Size([4, 1, 256, 256])


100%|██████████████████████████| 20/20 [01:07<00:00,  3.38s/it]
  0%|                                   | 0/20 [00:00<?, ?it/s]

torch.Size([4, 1, 256, 256])
Epoch [2/50], Step [0/20], D_loss: 0.0140, G_loss: 7.8427


  5%|█▎                         | 1/20 [00:04<01:18,  4.13s/it]

Model saved at epoch {i}
torch.Size([4, 1, 256, 256])


 10%|██▋                        | 2/20 [00:07<01:07,  3.75s/it]

torch.Size([4, 1, 256, 256])


 15%|████                       | 3/20 [00:10<00:59,  3.50s/it]

torch.Size([4, 1, 256, 256])


 20%|█████▍                     | 4/20 [00:14<00:54,  3.40s/it]

torch.Size([4, 1, 256, 256])


 25%|██████▊                    | 5/20 [00:17<00:51,  3.42s/it]

torch.Size([4, 1, 256, 256])


 30%|████████                   | 6/20 [00:20<00:47,  3.42s/it]

torch.Size([4, 1, 256, 256])


 35%|█████████▍                 | 7/20 [01:13<04:12, 19.45s/it]

torch.Size([4, 1, 256, 256])


 40%|██████████▊                | 8/20 [01:16<02:50, 14.23s/it]

torch.Size([4, 1, 256, 256])


 45%|████████████▏              | 9/20 [01:19<01:58, 10.77s/it]

torch.Size([4, 1, 256, 256])


 50%|█████████████             | 10/20 [01:22<01:24,  8.44s/it]

torch.Size([4, 1, 256, 256])
Epoch [2/50], Step [10/20], D_loss: 0.0000, G_loss: 75.0374


 55%|██████████████▎           | 11/20 [01:26<01:02,  6.92s/it]

Model saved at epoch {i}
torch.Size([4, 1, 256, 256])


 60%|███████████████▌          | 12/20 [01:29<00:45,  5.68s/it]

torch.Size([4, 1, 256, 256])


 65%|████████████████▉         | 13/20 [01:32<00:34,  4.89s/it]

torch.Size([4, 1, 256, 256])


 70%|██████████████████▏       | 14/20 [01:35<00:25,  4.30s/it]

torch.Size([4, 1, 256, 256])


 75%|███████████████████▌      | 15/20 [01:38<00:19,  3.90s/it]

torch.Size([4, 1, 256, 256])


 80%|████████████████████▊     | 16/20 [01:41<00:14,  3.63s/it]

torch.Size([4, 1, 256, 256])


 85%|██████████████████████    | 17/20 [01:44<00:10,  3.44s/it]

torch.Size([4, 1, 256, 256])


 90%|███████████████████████▍  | 18/20 [01:47<00:06,  3.31s/it]

torch.Size([4, 1, 256, 256])


In [None]:
generator = Generator(noise_dim).to(device)
generator.load_state_dict(torch.load('generator_trained.pth'))
generator.eval()

In [None]:
# Generate a new image
#from synthetic_map import generate_image
def generate_image(generator, simulated_map, noise_dim):
    # Prepare simulated map and noise map
    simulated_map = simulated_map.to(device)
    noise_map = torch.randn(1, noise_dim, pixel_size, pixel_size, device=device)

    # Generate image
    with torch.no_grad():
        generated_image = generator(simulated_map, noise_map)
    
    return generated_image

simulated_map = torch.from_numpy(image_dpm_np).reshape(1, 1 , pixel_size, pixel_size)
synthetic_image = generate_image(generator, simulated_map, noise_dim).numpy()
synthetic_image = synthetic_image.squeeze()
print(synthetic_image.shape)
#print(generated_image.size())

# Plot the image

plt.imshow(image_dpm_np, cmap='gray')
plt.axis('off')
plt.show()

plt.imshow(synthetic_image, cmap='gray')
plt.axis('off')
plt.show()

In [None]:
print(synthetic_image.squeeze().shape)