<a href="https://colab.research.google.com/github/johan-stph/thesis/blob/main/1000images_with_heatmaps.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Prep

In [None]:
!pip install grad-cam

In [None]:
import torch
import torchvision
import torchvision.transforms as transforms
from torchvision.models import resnet50
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from pytorch_grad_cam.utils.image import show_cam_on_image
from torchvision.models import ResNet50_Weights
from torchvision import models
import matplotlib.pyplot as plt
import numpy as np
import cv2

In [None]:
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
device

In [None]:
# Load CIFAR-10: upsample to 224x224 (the ResNet input size) and map each channel
# from [0, 1] to [-1, 1].
# NOTE(review): the pretrained backbones used below ship with ImageNet mean/std
# normalization, not (0.5, 0.5, 0.5); confirm this mismatch is intended.
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]
)

batchsize = 10  # small batches so the three Grad-CAM passes fit in GPU memory

trainset = torchvision.datasets.CIFAR10(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)
# shuffle=False: the first batches are always the same images, so the subset is reproducible.
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batchsize, shuffle=False)

### Resnet


In [None]:
# ImageNet-pretrained ResNet-50 with its classifier replaced by a fresh 10-way layer.
# NOTE(review): the new `fc` layer is randomly initialised and is not trained anywhere
# in this notebook, so Grad-CAM targets for this model use untrained logits; confirm intent.
model_resnet = resnet50(weights=ResNet50_Weights.DEFAULT)
model_resnet.fc = torch.nn.Linear(model_resnet.fc.in_features, 10)  # CIFAR-10 has 10 classes
model_resnet = model_resnet.to(device)

In [None]:
# Grad-CAM on the output of the last bottleneck block of ResNet-50's final stage.
target_layers_res = [model_resnet.layer4[-1]]
cam_res = GradCAM(
    model=model_resnet,
    target_layers=target_layers_res,
    use_cuda=torch.cuda.is_available(),  # argument exists only in grad-cam < 1.5
)

### Densenet

In [None]:
# ImageNet-pretrained DenseNet-121, kept with its original classifier head.
# NOTE(review): the CIFAR label indices used as Grad-CAM targets therefore select
# ImageNet logits 0-9 of this model; confirm intent.
model_dense = models.densenet121(weights=models.DenseNet121_Weights.DEFAULT).to(device)
model_dense.eval();

In [None]:
# Grad-CAM target: the last dense layer of the final dense block.
# NOTE(review): a single dense layer returns only its newly produced feature maps, not
# the block's concatenated output; the pytorch-grad-cam README suggests
# `model.features[-1]` for DenseNets. Confirm this choice.
target_layers_dense = [model_dense.features.denseblock4.denselayer16]
cam_dense = GradCAM(
    model=model_dense,
    target_layers=target_layers_dense,
    use_cuda=torch.cuda.is_available(),  # argument exists only in grad-cam < 1.5
)

### VGG19

In [None]:
# ImageNet-pretrained VGG-19, kept with its original classifier head.
# NOTE(review): CIFAR label indices used as targets select ImageNet logits 0-9; confirm intent.
model_vgg19 = models.vgg19(weights=models.VGG19_Weights.DEFAULT).to(device)
model_vgg19.eval();

In [None]:
# Grad-CAM target: `features[-1]`, which in torchvision's VGG-19 is the final
# MaxPool2d (not a convolution). This is the layer the pytorch-grad-cam README
# suggests for VGG models.
target_layers_vgg19 = [model_vgg19.features[-1]]
cam_vgg19 = GradCAM(
    model=model_vgg19,
    target_layers=target_layers_vgg19,
    use_cuda=torch.cuda.is_available(),  # argument exists only in grad-cam < 1.5
)

## Create Dataset

In [None]:
# Collect N batches of CIFAR-10 training images with their labels and a combined
# Grad-CAM heatmap: the pixel-wise maximum over the VGG-19, DenseNet-121 and ResNet-50 maps.
all_images = []
all_labels = []
all_combined_heatmaps = []

N = 100  # Number of batches to process -> N * batchsize images in total
# batch size = 10 to avoid overflow of GPU memory

for batch, (images, labels) in enumerate(trainloader):
    # Stop after exactly N batches (batch indices 0 .. N-1); `batch > N` processed N + 1.
    if batch >= N:
        break

    # Explain every image with respect to its own ground-truth class index.
    targets = [ClassifierOutputTarget(label.item()) for label in labels]
    grayscale_cams_vgg19 = cam_vgg19(input_tensor=images, targets=targets, aug_smooth=True, eigen_smooth=True)
    grayscale_cams_dense = cam_dense(input_tensor=images, targets=targets, aug_smooth=True, eigen_smooth=True)
    grayscale_cams_res = cam_res(input_tensor=images, targets=targets, aug_smooth=True, eigen_smooth=True)

    # Combine heatmaps: keep, per pixel, the strongest response among the three models.
    stacked_heatmaps = np.stack((grayscale_cams_vgg19, grayscale_cams_dense, grayscale_cams_res))
    combined_heatmaps = np.max(stacked_heatmaps, axis=0)

    # Append to lists (one entry per batch)
    all_images.append(images.cpu().numpy())
    all_labels.append(labels.cpu().numpy())
    all_combined_heatmaps.append(combined_heatmaps)




In [None]:
import matplotlib.pyplot as plt

# Convert a normalized image tensor back to a displayable numpy array.
def denormalize(tensor_img):
    """Undo ``Normalize((0.5,)*3, (0.5,)*3)`` and return a channels-last numpy array.

    Args:
        tensor_img: 3-D image tensor, channels first (e.g. ``(3, 224, 224)``),
            normalized to roughly ``[-1, 1]``. May live on any device and may require grad.

    Returns:
        ``numpy.ndarray`` with the first axis moved last (``(H, W, C)``), clipped to ``[0, 1]``.
    """
    # `.numpy()` refuses tensors that require grad or live on a GPU, so detach and move first.
    img = tensor_img.detach().cpu()
    return ((img * 0.5) + 0.5).clamp(0, 1).numpy().transpose(1, 2, 0)

# Show the first collected sample: the denormalised image next to its combined heatmap.
sample_image = denormalize(torch.tensor(all_images[0][0]))  # channels-last for imshow
sample_label = all_labels[0][0]
sample_heatmap = all_combined_heatmaps[0][0]  # presumably (H, W); verify against the CAM output

fig, axs = plt.subplots(1, 2, figsize=(12, 6))

# Left panel: the input image and its class index
axs[0].imshow(sample_image)
axs[0].set(title=f'Original Image - Label: {sample_label}')
axs[0].set_axis_off()

# Right panel: pixel-wise maximum of the three Grad-CAM maps
axs[1].imshow(sample_heatmap, cmap='jet')
axs[1].set(title='Combined Heatmap')
axs[1].set_axis_off()

plt.show()

In [None]:


# Stack the per-batch lists into arrays with a leading batch-index axis
# (assumes every batch has the same size) and write them to the working directory.
all_images_array, all_labels_array, all_combined_heatmaps_array = (
    np.array(all_images),
    np.array(all_labels),
    np.array(all_combined_heatmaps),
)

np.save('all_images.npy', all_images_array)
np.save('all_labels.npy', all_labels_array)
np.save('all_combined_heatmaps.npy', all_combined_heatmaps_array)


In [None]:
import os
from google.colab import drive

# Drive is otherwise only mounted further down in this notebook; mount it here too so
# the save below does not fail on a fresh Colab runtime (a repeat mount is harmless).
drive.mount('/content/drive')

save_path = '/content/drive/My Drive/bachelorarbeit/cifar10/1000images/'
os.makedirs(save_path, exist_ok=True)  # np.save does not create missing folders
np.save(save_path + 'all_images.npy', all_images_array)
np.save(save_path + 'all_labels.npy', all_labels_array)
np.save(save_path + 'all_combined_heatmaps.npy', all_combined_heatmaps_array)


# Start here (AB HIER STARTEN): load the saved arrays and train the models

In [None]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
from torchvision import datasets
import matplotlib.pyplot as plt
import cv2
import numpy as np
import pandas as pd
from torch.utils.data import Dataset
from torchvision.io import read_image
from torchvision.transforms import ToTensor
import torch.nn.functional as F
from sklearn import preprocessing
from sklearn.model_selection import train_test_split


In [None]:
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
device


In [None]:
from google.colab import drive

# Make Google Drive available under /content/drive (prompts for authorisation in Colab).
drive.mount(mountpoint='/content/drive')


In [None]:
load_path = '/content/drive/My Drive/bachelorarbeit/cifar10/1000images/'
# Load the three arrays written by the dataset-creation section, in the same order.
all_images_array, all_labels_array, all_combined_heatmaps_array = (
    np.load(load_path + file_name)
    for file_name in ('all_images.npy', 'all_labels.npy', 'all_combined_heatmaps.npy')
)

In [None]:
# Copy the numpy arrays into tensors: float32 for images/heatmaps, int64 class indices.
images_tensor, labels_tensor, heatmaps_tensor = (
    torch.tensor(all_images_array, dtype=torch.float32),
    torch.tensor(all_labels_array, dtype=torch.long),
    torch.tensor(all_combined_heatmaps_array, dtype=torch.float32),
)


In [None]:
# Merge the (batch index, sample-in-batch) axes into a single sample axis.
images_tensor = torch.reshape(images_tensor, (-1, 3, 224, 224))
images_tensor.shape  # (num_images, 3, 224, 224)

In [None]:
# One class index per image.
labels_tensor = torch.flatten(labels_tensor)
labels_tensor.shape  # (num_images,)

In [None]:
# One single-channel heatmap per image, laid out (N, C=1, H, W) for conv2d.
heatmaps_tensor = torch.reshape(heatmaps_tensor, (-1, 224, 224))[:, None]
heatmaps_tensor.shape  # (num_images, 1, 224, 224)

In [None]:

# Function to create a Gaussian kernel
def create_gaussian_kernel(kernel_size=9, sigma=1.5):
    """Build a normalized, isotropic 2-D Gaussian kernel.

    Args:
        kernel_size: side length of the square kernel, in pixels.
        sigma: standard deviation of the Gaussian, in pixels.

    Returns:
        float32 tensor of shape ``(kernel_size, kernel_size)`` whose entries sum to 1.
    """
    centre = (kernel_size - 1) / 2.
    two_var = 2 * sigma ** 2.

    # Squared distance of every row/column index from the kernel centre ...
    offsets_sq = (torch.arange(kernel_size).float() - centre) ** 2.
    # ... combined by broadcasting: column term + row term for each (row, col) cell.
    dist_sq = offsets_sq[None, :] + offsets_sq[:, None]

    weights = torch.exp(-dist_sq / two_var)
    return weights / torch.sum(weights)

# A wide, strong blur: 51x51 kernel with sigma = 20 px, shaped (out_ch, in_ch, kH, kW).
gaussian_kernel = create_gaussian_kernel(kernel_size=51, sigma=20.0).view(1, 1, 51, 51)

# Half-kernel zero padding keeps the spatial size unchanged.
# NOTE(review): zero padding pulls values down near the image borders; confirm this is acceptable.
smoothed_heatmaps = F.conv2d(heatmaps_tensor, gaussian_kernel, padding=gaussian_kernel.shape[-1] // 2)

smoothed_heatmaps.shape

In [None]:
def denormalize(tensor_img):
    """Undo ``Normalize((0.5,)*3, (0.5,)*3)`` and return a channels-last numpy array.

    Args:
        tensor_img: 3-D image tensor, channels first (e.g. ``(3, 224, 224)``),
            normalized to roughly ``[-1, 1]``. May live on any device and may require grad.

    Returns:
        ``numpy.ndarray`` with the first axis moved last (``(H, W, C)``), clipped to ``[0, 1]``.
    """
    # `.numpy()` refuses tensors that require grad or live on a GPU, so detach and move first.
    img = tensor_img.detach().cpu()
    return ((img * 0.5) + 0.5).clamp(0, 1).numpy().transpose(1, 2, 0)

# Compare the first raw heatmap with its Gaussian-smoothed version.
# images_tensor[0] is already a tensor; re-wrapping it in torch.tensor() only triggers a copy warning.
sample_image = denormalize(images_tensor[0])  # Convert to shape (height, width, channels)
sample_label = labels_tensor[0].item()  # plain int, so the title does not show a tensor repr
sample_heatmap = heatmaps_tensor[0].detach().numpy().transpose((1, 2, 0))
smoothed_heatmap = smoothed_heatmaps[0].detach().numpy().transpose((1, 2, 0))

print("Original Heatmap Values:", sample_heatmap[0:5, 0:5, 0])
print("Smoothed Heatmap Values:", smoothed_heatmap[0:5, 0:5, 0])

# Plotting
fig, axs = plt.subplots(1, 3, figsize=(16, 6))

# Original image
axs[0].imshow(sample_image)
axs[0].set_title(f'Original Image - Label: {sample_label}')
axs[0].axis('off')

# Original Heatmap
axs[1].imshow((sample_heatmap * 255).astype(np.uint8), cmap='hot')
axs[1].set_title('Original Heatmap')
axs[1].axis('off')

# Smoothed Heatmap
axs[2].imshow((smoothed_heatmap * 255).astype(np.uint8), cmap='hot')
axs[2].set_title('Smoothed Heatmap')
axs[2].axis('off')

plt.show()



### Test and Training Dataset

In [None]:
class MyDataset(Dataset):
    """Map-style dataset returning ``(image, label, heatmap)`` triples.

    The three containers are indexed in lockstep along their first dimension.

    Args:
        images_tensor: images, indexed by sample.
        labels_tensor: class labels, indexed by sample.
        heatmaps_tensor: target heatmaps, indexed by sample.

    Raises:
        ValueError: if the three inputs differ in length. Without this check they would be
            silently misaligned, or fail only when a late index is requested.
    """

    def __init__(self, images_tensor, labels_tensor, heatmaps_tensor):
        lengths = (len(images_tensor), len(labels_tensor), len(heatmaps_tensor))
        if len(set(lengths)) != 1:
            raise ValueError(
                "images, labels and heatmaps must have the same length, got "
                f"{lengths[0]}, {lengths[1]} and {lengths[2]}"
            )
        self.images = images_tensor
        self.labels = labels_tensor
        self.heatmaps = heatmaps_tensor

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        return self.images[idx], self.labels[idx], self.heatmaps[idx]

# Pair each image/label with its Gaussian-smoothed heatmap (not the raw Grad-CAM map).
dataset = MyDataset(
    images_tensor=images_tensor,
    labels_tensor=labels_tensor,
    heatmaps_tensor=smoothed_heatmaps,
)


In [None]:
from torch.utils.data import random_split

# 95 % / 5 % train/test split.
total_len = len(dataset)
print(total_len)
train_len = int(0.95 * total_len)
test_len = total_len - train_len

# A fixed seed makes the split identical on every run, so test results stay comparable.
train_dataset, test_dataset = random_split(
    dataset, [train_len, test_len], generator=torch.Generator().manual_seed(42)
)

In [None]:
# Inspect one sample; the image is channels-first.
(image_one, label_one, heatmap_one) = dataset[0]
image_one.shape

# Generator

In [None]:
class ResidualBlock(nn.Module):
    """Two 3x3 conv + BatchNorm layers wrapped in an identity shortcut.

    The shortcut is a plain addition, so ``in_channels`` must equal ``out_channels``.
    The spatial size is preserved (3x3 kernels, padding 1).
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.block = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(0.2),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
        )
        self.relu = nn.LeakyReLU(0.2)

    def forward(self, x):
        # Add the untouched input back onto the transformed features, then activate.
        return self.relu(self.block(x) + x)

class Generator(nn.Module):
    """Encoder/decoder mapping an image to a single-channel map passed through a sigmoid.

    Four conv + 2x2 max-pool stages halve the resolution four times, a residual block
    refines the bottleneck, and four stride-2 transposed convolutions double it back.
    Input height and width should therefore be divisible by 16.
    """

    def __init__(self, in_channels=3, out_channels=1):
        super().__init__()

        # Encoder: registered as conv1 .. conv4 (64 -> 512 channels).
        enc_widths = [in_channels, 64, 128, 256, 512]
        for stage, (c_in, c_out) in enumerate(zip(enc_widths, enc_widths[1:]), start=1):
            setattr(self, f"conv{stage}", self.conv_block(c_in, c_out))

        # Bottleneck
        self.res_block = ResidualBlock(512, 512)

        # Decoder: registered as deconv1 .. deconv4; only the last one ends in a sigmoid.
        dec_widths = [512, 256, 128, 64, out_channels]
        last_stage = len(dec_widths) - 1
        for stage, (c_in, c_out) in enumerate(zip(dec_widths, dec_widths[1:]), start=1):
            setattr(self, f"deconv{stage}", self.deconv_block(c_in, c_out, last_layer=stage == last_stage))

    def conv_block(self, in_channels, out_channels):
        """3x3 conv (same padding) -> BatchNorm -> LeakyReLU(0.2) -> 2x2 max-pool (halves H, W)."""
        conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
        return nn.Sequential(conv, nn.BatchNorm2d(out_channels), nn.LeakyReLU(0.2), nn.MaxPool2d(2, stride=2))

    def deconv_block(self, in_channels, out_channels, last_layer=False):
        """4x4 stride-2 transposed conv (doubles H, W) -> BatchNorm -> sigmoid or LeakyReLU(0.2)."""
        upsample = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1)
        norm = nn.BatchNorm2d(out_channels)
        activation = nn.Sigmoid() if last_layer else nn.LeakyReLU(0.2)
        return nn.Sequential(upsample, norm, activation)

    def forward(self, x):
        pipeline = (
            self.conv1, self.conv2, self.conv3, self.conv4,
            self.res_block,
            self.deconv1, self.deconv2, self.deconv3, self.deconv4,
        )
        for stage in pipeline:
            x = stage(x)
        return x


In [None]:
# Smoke test: push one dataset sample through an untrained Generator.
image_one, label_one, heatmap_one = dataset[0]
image_one = image_one[None]  # add a batch dimension
print(image_one.shape)

test_generator = Generator()
test_generator(image_one).shape

In [None]:
class Discriminator(nn.Module):
    """Convolutional critic that scores (image, heatmap) pairs.

    The heatmap is resized to the image's spatial size and concatenated to it along the
    channel axis (3 image channels + 1 heatmap channel = ``in_channels`` by default).
    The network ends in a convolution without pooling, so it returns a spatial grid of
    raw scores (one per patch) rather than a single value.
    """

    def __init__(self, in_channels=4):
        super(Discriminator, self).__init__()

        # Four stride-2 4x4 convolutions (64 -> 512 channels), then a 1-channel output conv.
        self.model = nn.Sequential(
            nn.Conv2d(in_channels, 64, 4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 128, 4, stride=2, padding=1),
            nn.InstanceNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(128, 256, 4, stride=2, padding=1),
            nn.InstanceNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(256, 512, 4, stride=2, padding=1),
            nn.InstanceNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(512, 1, 4, padding=1, bias=False)  # Output layer
        )

    def forward(self, heatmap, img):
        """Score a batch of heatmaps conditioned on their images.

        Args:
            heatmap: ``(B, 1, h, w)`` heatmaps; resized to the image size if they differ.
            img: ``(B, 3, H, W)`` images.
        """
        # Match the heatmap to the image resolution (previously hard-coded to 224x224,
        # which broke any other image size).
        heatmap = F.interpolate(heatmap, size=tuple(img.shape[-2:]))
        img_input = torch.cat((img, heatmap), dim=1)  # Concatenate along the channel dimension
        return self.model(img_input)




In [1]:
# Test the Discriminator with random inputs.
# NOTE: the saved NameError output below came from running this cell first in a fresh
# kernel (In [1]); it needs the import and Discriminator cells above.
image, label, heatmap = torch.randn(3, 224, 224), torch.randn(1), torch.randn(1, 224, 224)
test_discriminator = Discriminator()
image, heatmap = image[None], heatmap[None]  # add a batch dimension to both
print("Input Heatmap Shape:", heatmap.shape)  # [1, 1, 224, 224]
print("Input Image Shape:", image.shape)  # [1, 3, 224, 224]
output = test_discriminator(heatmap, image)
print("Output Shape:", output.shape)  # grid of patch scores

NameError: ignored

In [None]:
def train(dataloader, generator, discriminator, optim_g, optim_d, loss_fn, device, epochs, display_images=5):
    """Train the heatmap GAN with a purely adversarial generator loss.

    Args:
        dataloader: iterable of ``(image, label, heatmap)`` batches (``label`` is unused).
        generator: module mapping an image batch to a heatmap batch.
        discriminator: module called as ``discriminator(heatmap, image)``.
        optim_g: optimizer over the generator's parameters.
        optim_d: optimizer over the discriminator's parameters.
        loss_fn: criterion comparing discriminator outputs with all-ones / all-zeros
            targets (``nn.BCEWithLogitsLoss`` in this notebook).
        device: device every batch is moved to.
        epochs: number of passes over ``dataloader``.
        display_images: maximum number of samples from the epoch's last batch to plot.

    Side effects:
        Shows a sample grid after every epoch and a loss curve every 10 epochs, and saves
        ``generator_gan_v1_epoch_{epoch}.pth`` every 20 epochs.

    Raises:
        ValueError: if ``dataloader`` yields no batches in an epoch.
    """
    generator.train()
    discriminator.train()

    # Mean loss per epoch (previously only the last batch's loss was kept, so the
    # "per epoch" curves were dominated by single-batch noise).
    generator_losses = []
    discriminator_losses = []

    for epoch in range(epochs):
        g_loss_sum = 0.0
        d_loss_sum = 0.0
        n_batches = 0
        for batch, (image, label, heatmap) in enumerate(dataloader):
            image, heatmap = image.to(device), heatmap.to(device)

            # Generate heatmaps from the input images
            generated_images = generator(image)

            # Discriminator step: real pairs -> 1, generated pairs -> 0.
            # detach() keeps this step from back-propagating into the generator.
            real_output = discriminator(heatmap, image)
            fake_output = discriminator(generated_images.detach(), image)

            real_loss = loss_fn(real_output, torch.ones_like(real_output))
            fake_loss = loss_fn(fake_output, torch.zeros_like(fake_output))
            discriminator_loss = (real_loss + fake_loss) / 2

            optim_d.zero_grad()
            discriminator_loss.backward()
            optim_d.step()

            # Generator step: push the freshly updated discriminator towards "real".
            # (Gradients this leaves on the discriminator are cleared by optim_d.zero_grad().)
            fake_output = discriminator(generated_images, image)
            generator_loss = loss_fn(fake_output, torch.ones_like(fake_output))

            optim_g.zero_grad()
            generator_loss.backward()
            optim_g.step()

            g_loss_sum += generator_loss.item()
            d_loss_sum += discriminator_loss.item()
            n_batches += 1

            if batch % 100 == 0:
                print(f"Epoch {epoch+1}/{epochs}, Batch {batch}/{len(dataloader)}, Generator Loss: {generator_loss.item()}, Discriminator Loss: {discriminator_loss.item()}")

        if n_batches == 0:
            raise ValueError("dataloader yielded no batches")

        generator_losses.append(g_loss_sum / n_batches)
        discriminator_losses.append(d_loss_sum / n_batches)

        # Visualise the epoch's last batch: input image / generated / target heatmap.
        images_cpu = image.detach().cpu()
        generated_np = generated_images.detach().cpu().numpy().transpose((0, 2, 3, 1))
        heatmaps_np = heatmap.detach().cpu().numpy().transpose((0, 2, 3, 1))
        n_show = min(display_images, images_cpu.shape[0])  # the last batch may be smaller
        fig = plt.figure(figsize=(30, 4))
        for i in range(n_show):
            # Display the (RGB) input image
            ax = fig.add_subplot(3, n_show, i + 1, xticks=[], yticks=[])
            ax.imshow(denormalize(images_cpu[i]))

            # Display generated heatmap
            ax = fig.add_subplot(3, n_show, i + 1 + n_show, xticks=[], yticks=[])
            ax.imshow((generated_np[i] * 255).astype(np.uint8), cmap='hot')

            # Display original heatmap
            ax = fig.add_subplot(3, n_show, i + 1 + 2 * n_show, xticks=[], yticks=[])
            ax.imshow((heatmaps_np[i] * 255).astype(np.uint8), cmap='hot')

        plt.show()

        # Plot the generator and discriminator loss every 10 epochs
        if (epoch + 1) % 10 == 0:
            plt.figure(figsize=(10, 5))
            plt.title("Generator and Discriminator Loss During Training")
            plt.plot(generator_losses, label="Generator")
            plt.plot(discriminator_losses, label="Discriminator")
            plt.xlabel("Epochs")
            plt.ylabel("Loss")
            plt.legend()
            plt.show()
        if (epoch + 1) % 20 == 0:
            torch.save(generator.state_dict(), f'generator_gan_v1_epoch_{epoch+1}.pth')


In [None]:
# Shared GAN hyperparameters.
loss_fn = nn.BCEWithLogitsLoss()  # the discriminators output raw scores (no sigmoid)
epochs = 100
batch_size = 32
lr_disc = lr_gener = 0.0002  # same Adam learning rate for both networks
betas = (0.5, 0.999)  # Adam (beta1, beta2)

In [None]:
generator = Generator().to(device)
discriminator = Discriminator().to(device)

optim_g = torch.optim.Adam(generator.parameters(), lr=lr_gener, betas=betas)
optim_d = torch.optim.Adam(discriminator.parameters(), lr=lr_disc, betas=betas)

# Train only on the training split, so `test_dataset` stays unseen for evaluation
# (previously the full `dataset`, test split included, was used). Shuffle so the
# mini-batch composition changes between epochs.
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

In [None]:
train(
    train_dataloader, generator, discriminator,
    optim_g, optim_d, loss_fn, device,
    epochs=epochs,
)

# Train v2 (adversarial + MSE loss)

In [None]:
# Per-epoch loss history kept at module level so it can be inspected after `train_v2`
# returns. NOTE: calling `train_v2` again without re-running this cell keeps appending.
generator_losses = []
discriminator_losses = []

def train_v2(dataloader, generator, discriminator, optim_g, optim_d, loss_fn, device, epochs, display_images=5):
    """Train the heatmap GAN with an adversarial + MSE reconstruction generator loss.

    Same loop as ``train``, except the generator loss is
    ``gan_loss + 0.8 * mse(generated, target)``.

    Args:
        dataloader: iterable of ``(image, label, heatmap)`` batches (``label`` is unused).
        generator: module mapping an image batch to a heatmap batch.
        discriminator: module called as ``discriminator(heatmap, image)``.
        optim_g: optimizer over the generator's parameters.
        optim_d: optimizer over the discriminator's parameters.
        loss_fn: criterion comparing discriminator outputs with all-ones / all-zeros targets.
        device: device every batch is moved to.
        epochs: number of passes over ``dataloader``.
        display_images: maximum number of samples from the epoch's last batch to plot.

    Side effects:
        Appends each epoch's mean losses to the module-level ``generator_losses`` /
        ``discriminator_losses``, shows figures, and saves
        ``generator_gan_v2_epoch_{epoch}.pth`` every 20 epochs.

    Raises:
        ValueError: if ``dataloader`` yields no batches in an epoch.
    """
    generator.train()
    discriminator.train()

    lambda_factor = 0.8  # weight of the MSE term relative to the GAN term; tune this factor

    for epoch in range(epochs):
        g_loss_sum = 0.0
        d_loss_sum = 0.0
        n_batches = 0
        for batch, (image, label, heatmap) in enumerate(dataloader):
            image, heatmap = image.to(device), heatmap.to(device)

            # Generate heatmaps from images
            generated_images = generator(image)

            # Discriminator step: real pairs -> 1, generated (detached) pairs -> 0.
            real_output = discriminator(heatmap, image)
            fake_output = discriminator(generated_images.detach(), image)

            real_loss = loss_fn(real_output, torch.ones_like(real_output))
            fake_loss = loss_fn(fake_output, torch.zeros_like(fake_output))
            discriminator_loss = (real_loss + fake_loss) / 2

            optim_d.zero_grad()
            discriminator_loss.backward()
            optim_d.step()

            # Generator step: fool the discriminator AND stay close to the target heatmap.
            fake_output = discriminator(generated_images, image)
            gan_loss = loss_fn(fake_output, torch.ones_like(fake_output))
            mse_loss = F.mse_loss(generated_images, heatmap)
            generator_loss = gan_loss + lambda_factor * mse_loss

            optim_g.zero_grad()
            generator_loss.backward()
            optim_g.step()

            g_loss_sum += generator_loss.item()
            d_loss_sum += discriminator_loss.item()
            n_batches += 1

            if batch % 100 == 0:
                print(f"Epoch {epoch+1}/{epochs}, Batch {batch}/{len(dataloader)}, Generator Loss: {generator_loss.item()}, Discriminator Loss: {discriminator_loss.item()}")

        if n_batches == 0:
            raise ValueError("dataloader yielded no batches")

        # Mean over the epoch (previously only the last batch's value was recorded).
        generator_losses.append(g_loss_sum / n_batches)
        discriminator_losses.append(d_loss_sum / n_batches)

        # Visualise the epoch's last batch: input image / generated / target heatmap.
        images_cpu = image.detach().cpu()
        generated_np = generated_images.detach().cpu().numpy().transpose((0, 2, 3, 1))
        heatmaps_np = heatmap.detach().cpu().numpy().transpose((0, 2, 3, 1))
        n_show = min(display_images, images_cpu.shape[0])  # the last batch may be smaller

        fig = plt.figure(figsize=(30, 4))
        for i in range(n_show):
            # Display the (RGB) input image
            ax = fig.add_subplot(3, n_show, i + 1, xticks=[], yticks=[])
            ax.imshow(denormalize(images_cpu[i]))

            # Display generated heatmap
            ax = fig.add_subplot(3, n_show, i + 1 + n_show, xticks=[], yticks=[])
            ax.imshow((generated_np[i] * 255).astype(np.uint8), cmap='hot')

            # Display original heatmap
            ax = fig.add_subplot(3, n_show, i + 1 + 2 * n_show, xticks=[], yticks=[])
            ax.imshow((heatmaps_np[i] * 255).astype(np.uint8), cmap='hot')

        plt.show()

        # Plot the generator and discriminator loss every 10 epochs
        if (epoch + 1) % 10 == 0:
            plt.figure(figsize=(10, 5))
            plt.title("Generator and Discriminator Loss During Training")
            plt.plot(generator_losses, label="Generator")
            plt.plot(discriminator_losses, label="Discriminator")
            plt.xlabel("Epochs")
            plt.ylabel("Loss")
            plt.legend()
            plt.show()
        if (epoch + 1) % 20 == 0:
            torch.save(generator.state_dict(), f'generator_gan_v2_epoch_{epoch+1}.pth')

In [None]:
generator_v2 = Generator().to(device)
discriminator_v2 = Discriminator().to(device)

optim_g_v2 = torch.optim.Adam(generator_v2.parameters(), lr=lr_gener, betas=betas)
optim_d_v2 = torch.optim.Adam(discriminator_v2.parameters(), lr=lr_disc, betas=betas)

# Train on the training split only (keeps `test_dataset` unseen) and reshuffle each epoch.
train_dataloader_v2 = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

In [None]:
train_v2(
    train_dataloader_v2, generator_v2, discriminator_v2,
    optim_g_v2, optim_d_v2, loss_fn, device,
    epochs=epochs,
)

# GAN with Context (class-conditional GAN)

In [None]:
# NOTE(review): this redefines the ResidualBlock declared in the "Generator" section
# with identical behaviour; consider keeping a single definition.
class ResidualBlock(nn.Module):
    """Two 3x3 conv + BatchNorm layers wrapped in an identity shortcut.

    The shortcut is a plain addition, so ``in_channels`` must equal ``out_channels``.
    """

    def __init__(self, in_channels, out_channels):
        super(ResidualBlock, self).__init__()
        self.block = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(0.2),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels)
        )
        self.relu = nn.LeakyReLU(0.2)

    def forward(self, x):
        residual = x
        x = self.block(x)
        x += residual
        x = self.relu(x)
        return x

class Generator_v2(nn.Module):
    """Class-conditional encoder/decoder producing a single-channel map through a sigmoid.

    A learned embedding of the class label is broadcast over the image plane and
    concatenated to the input channels before the encoder. Input height and width
    should be divisible by 16.

    Args:
        num_classes: number of distinct labels.
        in_channels: image channels.
        out_channels: output heatmap channels.
        embedding_dim: size of the label embedding (previously hard-coded to 50; the default
            keeps existing checkpoints loadable).
    """

    def __init__(self, num_classes=10, in_channels=3, out_channels=1, embedding_dim=50):
        super(Generator_v2, self).__init__()
        self.embedding_dim = embedding_dim

        # Embedding for class label
        self.embedding = nn.Embedding(num_classes, embedding_dim)

        # Encoder: the label planes add `embedding_dim` input channels.
        self.conv1 = self.conv_block(in_channels + embedding_dim, 64)
        self.conv2 = self.conv_block(64, 128)
        self.conv3 = self.conv_block(128, 256)
        self.conv4 = self.conv_block(256, 512)

        # Middle part
        self.res_block = ResidualBlock(512, 512)

        # Decoder (Transpose Convolution to upsample)
        self.deconv1 = self.deconv_block(512, 256)
        self.deconv2 = self.deconv_block(256, 128)
        self.deconv3 = self.deconv_block(128, 64)
        self.deconv4 = self.deconv_block(64, out_channels, last_layer=True)

    def conv_block(self, in_channels, out_channels):
        """3x3 conv (same padding) -> BatchNorm -> LeakyReLU(0.2) -> 2x2 max-pool."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(0.2),
            nn.MaxPool2d(2, stride=2)
        )

    def deconv_block(self, in_channels, out_channels, last_layer=False):
        """4x4 stride-2 transposed conv -> BatchNorm -> sigmoid (last) or LeakyReLU(0.2)."""
        layers = [
            nn.ConvTranspose2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(out_channels),
        ]
        if last_layer:
            layers.append(nn.Sigmoid())
        else:
            layers.append(nn.LeakyReLU(0.2))
        return nn.Sequential(*layers)

    def forward(self, x, labels):
        """Args: x -- image batch ``(B, C, H, W)``; labels -- class indices ``(B,)``."""
        # (B,) -> (B, E) -> (B, E, H, W): one constant plane per embedding component.
        label_maps = self.embedding(labels)
        label_maps = label_maps.view(label_maps.size(0), self.embedding_dim, 1, 1)
        label_maps = label_maps.expand(-1, -1, x.size(2), x.size(3))

        # Concatenate labels with images
        x = torch.cat([x, label_maps], dim=1)

        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)
        x = self.res_block(x)
        x = self.deconv1(x)
        x = self.deconv2(x)
        x = self.deconv3(x)
        x = self.deconv4(x)
        return x


In [None]:
class Discriminator_v2(nn.Module):
    """Class-conditional convolutional critic for (image, heatmap, label) triples.

    The heatmap is resized to the image size; the label embedding is broadcast over the
    image plane; all three are concatenated channel-wise. The network ends in a convolution
    without pooling, so it returns a spatial grid of raw scores.

    Args:
        num_classes: number of distinct labels.
        in_channels: image + heatmap channels (3 + 1 by default).
        embedding_dim: size of the label embedding (previously hard-coded to 50).
    """

    def __init__(self, num_classes=10, in_channels=4, embedding_dim=50):
        super(Discriminator_v2, self).__init__()
        self.embedding_dim = embedding_dim

        # Embedding for class label
        self.embedding = nn.Embedding(num_classes, embedding_dim)

        # Discriminator model
        self.model = nn.Sequential(
            nn.Conv2d(in_channels + embedding_dim, 64, 4, stride=2, padding=1),  # + label planes
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 128, 4, stride=2, padding=1),
            nn.InstanceNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(128, 256, 4, stride=2, padding=1),
            nn.InstanceNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(256, 512, 4, stride=2, padding=1),
            nn.InstanceNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(512, 1, 4, padding=1, bias=False)
        )

    def forward(self, heatmap, img, labels):
        """Args: heatmap ``(B, 1, h, w)``; img ``(B, 3, H, W)``; labels class indices ``(B,)``."""
        # Resize the heatmap to the image size (previously hard-coded to 224x224).
        heatmap = F.interpolate(heatmap, size=tuple(img.shape[-2:]))

        # Embed labels and broadcast them to (B, E, H, W)
        label_maps = self.embedding(labels)
        label_maps = label_maps.view(label_maps.size(0), self.embedding_dim, 1, 1)
        label_maps = label_maps.expand(-1, -1, heatmap.size(2), heatmap.size(3))

        # Concatenate image, heatmap and label planes along the channel dimension
        img_input = torch.cat((img, heatmap, label_maps), dim=1)

        return self.model(img_input)


In [None]:
def jaccard_index(heatmap1, heatmap2, threshold=0.2, empty_score=1.0):
    """Intersection-over-Union of two heatmaps after binarising them at ``threshold``.

    All elements (e.g. a whole batch) are pooled into a single score.

    Args:
        heatmap1, heatmap2: tensors of the same (or broadcastable) shape.
        threshold: values strictly greater than this count as "on".
        empty_score: value returned when neither heatmap has an "on" element.
            Previously this case divided 0 by 0 and returned NaN.

    Returns:
        0-dim float tensor in ``[0, 1]``.
    """
    # Convert to binary
    heatmap1_binary = (heatmap1 > threshold).float()
    heatmap2_binary = (heatmap2 > threshold).float()

    # Compute Jaccard Index (Intersection over Union)
    intersection = torch.sum(heatmap1_binary * heatmap2_binary)
    union = torch.sum(heatmap1_binary) + torch.sum(heatmap2_binary) - intersection

    if union == 0:
        # Both masks are empty: they agree perfectly, but 0 / 0 would give NaN.
        return torch.tensor(empty_score, dtype=union.dtype, device=union.device)
    return intersection / union


In [None]:
mse_losses = list()  # per-batch generator MSE, appended to by train_v3

In [None]:
# Per-epoch histories kept at module level so they can be inspected after `train_v3`
# returns. NOTE: calling `train_v3` again without re-running this cell keeps appending.
generator_losses = []

discriminator_losses = []
jaccard_indices = []

def train_v3(dataloader, generator, discriminator, optim_g, optim_d, loss_fn, device, epochs, display_images=5):
    """Train the class-conditional heatmap GAN (adversarial + MSE generator loss).

    Args:
        dataloader: iterable of ``(image, label, heatmap)`` batches.
        generator: module called as ``generator(image, label)``.
        discriminator: module called as ``discriminator(heatmap, image, label)``.
        optim_g: optimizer over the generator's parameters.
        optim_d: optimizer over the discriminator's parameters.
        loss_fn: criterion comparing discriminator outputs with all-ones / all-zeros targets.
        device: device every batch is moved to.
        epochs: number of passes over ``dataloader``.
        display_images: maximum number of samples from the epoch's last batch to plot.

    Side effects:
        Appends per-epoch mean losses and Jaccard index to the module-level
        ``generator_losses`` / ``discriminator_losses`` / ``jaccard_indices``, and every
        batch's MSE to ``mse_losses``. Shows figures and saves
        ``generator_cgan_loss-0_epoch_{epoch}.pth`` every 20 epochs.

    Raises:
        ValueError: if ``dataloader`` yields no batches in an epoch.
    """
    generator.train()
    discriminator.train()

    lambda_factor = 0.8  # weight of the MSE term relative to the GAN term; tune this factor

    for epoch in range(epochs):
        g_loss_sum = 0.0
        d_loss_sum = 0.0
        iou_sum = 0.0
        n_batches = 0
        for batch, (image, label, heatmap) in enumerate(dataloader):
            image, heatmap, label = image.to(device), heatmap.to(device), label.to(device)

            # Generate heatmaps from images and labels
            generated_images = generator(image, label)

            # Batch-level overlap between generated and target heatmaps (monitoring only).
            IoU = jaccard_index(generated_images, heatmap)

            # Discriminator step: real triples -> 1, generated (detached) triples -> 0.
            real_output = discriminator(heatmap, image, label)
            fake_output = discriminator(generated_images.detach(), image, label)

            real_loss = loss_fn(real_output, torch.ones_like(real_output))
            fake_loss = loss_fn(fake_output, torch.zeros_like(fake_output))
            discriminator_loss = (real_loss + fake_loss) / 2

            optim_d.zero_grad()
            discriminator_loss.backward()
            optim_d.step()

            # Generator step: fool the discriminator AND stay close to the target heatmap.
            fake_output = discriminator(generated_images, image, label)
            gan_loss = loss_fn(fake_output, torch.ones_like(fake_output))
            mse_loss = F.mse_loss(generated_images, heatmap)
            mse_losses.append(mse_loss.item())
            generator_loss = gan_loss + lambda_factor * mse_loss

            optim_g.zero_grad()
            generator_loss.backward()
            optim_g.step()

            g_loss_sum += generator_loss.item()
            d_loss_sum += discriminator_loss.item()
            iou_sum += IoU.item()
            n_batches += 1

            if batch % 100 == 0:
                print(f"Epoch {epoch+1}/{epochs}, Batch {batch}/{len(dataloader)}, Generator Loss: {generator_loss.item()}, Discriminator Loss: {discriminator_loss.item()}, Jaccard Index: {IoU.item()}")

        if n_batches == 0:
            raise ValueError("dataloader yielded no batches")

        # Means over the epoch (previously only the last batch's values were recorded).
        generator_losses.append(g_loss_sum / n_batches)
        discriminator_losses.append(d_loss_sum / n_batches)
        jaccard_indices.append(iou_sum / n_batches)

        # Visualise the epoch's last batch: input image / generated / target heatmap.
        images_cpu = image.detach().cpu()
        generated_np = generated_images.detach().cpu().numpy().transpose((0, 2, 3, 1))
        heatmaps_np = heatmap.detach().cpu().numpy().transpose((0, 2, 3, 1))
        n_show = min(display_images, images_cpu.shape[0])  # the last batch may be smaller

        fig = plt.figure(figsize=(30, 4))
        for i in range(n_show):
            # Display the (RGB) input image
            ax = fig.add_subplot(3, n_show, i + 1, xticks=[], yticks=[])
            ax.imshow(denormalize(images_cpu[i]))

            # Display generated heatmap
            ax = fig.add_subplot(3, n_show, i + 1 + n_show, xticks=[], yticks=[])
            ax.imshow((generated_np[i] * 255).astype(np.uint8), cmap='hot')

            # Display original heatmap
            ax = fig.add_subplot(3, n_show, i + 1 + 2 * n_show, xticks=[], yticks=[])
            ax.imshow((heatmaps_np[i] * 255).astype(np.uint8), cmap='hot')

        plt.show()

        # Plot the generator and discriminator loss every 10 epochs
        if (epoch + 1) % 10 == 0:
            plt.figure(figsize=(10, 5))
            plt.title("Generator and Discriminator Loss During Training")
            plt.plot(generator_losses, label="Generator")
            plt.plot(discriminator_losses, label="Discriminator")
            plt.xlabel("Epochs")
            plt.ylabel("Loss")
            plt.legend()
            plt.show()
        if (epoch + 1) % 20 == 0:
            torch.save(generator.state_dict(), f'generator_cgan_loss-0_epoch_{epoch+1}.pth')


In [None]:
print(f"{mse_losses}")  # per-batch MSE history collected by train_v3

[]


In [None]:
# Hyperparameters for the conditional GAN (same values as the unconditional runs).
loss_fn = nn.BCEWithLogitsLoss()  # the discriminators output raw scores (no sigmoid)
epochs = 100
batch_size = 32
lr_disc = lr_gener = 0.0002  # same Adam learning rate for both networks
betas = (0.5, 0.999)  # Adam (beta1, beta2)

In [None]:
generator_v3 = Generator_v2().to(device)
discriminator_v3 = Discriminator_v2().to(device)

optim_g_v3 = torch.optim.Adam(generator_v3.parameters(), lr=lr_gener, betas=betas)
optim_d_v3 = torch.optim.Adam(discriminator_v3.parameters(), lr=lr_disc, betas=betas)

# Train on the training split only (keeps `test_dataset` unseen) and reshuffle each epoch.
train_dataloader_v3 = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

In [None]:
train_v3(
    train_dataloader_v3, generator_v3, discriminator_v3,
    optim_g_v3, optim_d_v3, loss_fn, device,
    epochs=epochs,
)

## Evaluation of GAN Model


### Storing of generator

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import torch

def visualize_results(dataloader, generator, device, num_images=100, use_labels=False):
    """Plot input image, generated heatmap and target heatmap for up to ``num_images`` samples.

    Args:
        dataloader: iterable of ``(image, label, heatmap)`` batches.
        generator: trained generator; switched to eval mode (and left there).
        device: device the images (and labels, if used) are moved to.
        num_images: maximum number of samples to plot, one figure each.
        use_labels: call ``generator(image, label)`` (needed for the conditional
            ``Generator_v2``) instead of ``generator(image)``. Defaults to the
            unconditional call.
    """
    generator.eval()  # Set generator to evaluation mode
    count = 0

    for image, label, heatmap in dataloader:
        if count >= num_images:
            break

        image = image.to(device)

        # Generate heatmaps (no gradients needed for plotting)
        with torch.no_grad():
            if use_labels:
                generated_images = generator(image, label.to(device))
            else:
                generated_images = generator(image)

        # Move to CPU for visualization
        image = image.cpu()
        generated_images = generated_images.cpu().numpy().transpose((0, 2, 3, 1))
        heatmap = heatmap.cpu().detach().numpy().transpose((0, 2, 3, 1))

        for i in range(image.shape[0]):
            if count >= num_images:
                break

            fig = plt.figure(figsize=(15, 5))

            # Display the (RGB) input image
            ax = fig.add_subplot(1, 3, 1, xticks=[], yticks=[])
            ax.imshow(denormalize(image[i]))
            ax.set_title("Input Image")

            # Display generated heatmap
            ax = fig.add_subplot(1, 3, 2, xticks=[], yticks=[])
            ax.imshow((generated_images[i] * 255).astype(np.uint8), cmap='hot')
            ax.set_title("Generated Heatmap")

            # Display original heatmap
            ax = fig.add_subplot(1, 3, 3, xticks=[], yticks=[])
            ax.imshow((heatmap[i] * 255).astype(np.uint8), cmap='hot')
            ax.set_title("Original Heatmap")

            plt.show()

            count += 1


In [None]:
import torch.nn as nn

mse_loss = nn.MSELoss()

# NOTE(review): this redefines `jaccard_index` from the conditional-GAN section with a
# lower default threshold (0.15 instead of 0.2); later calls use this version.
def jaccard_index(heatmap1, heatmap2, threshold=0.15, empty_score=1.0):
    """Intersection-over-Union of two heatmaps after binarising them at ``threshold``.

    All elements of the inputs are pooled into a single score.

    Args:
        heatmap1, heatmap2: tensors of the same (or broadcastable) shape.
        threshold: values strictly greater than this count as "on".
        empty_score: value returned when neither heatmap has an "on" element
            (previously 0 / 0 = NaN, which poisoned any average over samples).

    Returns:
        0-dim float tensor in ``[0, 1]``.
    """
    # Convert to binary
    heatmap1_binary = (heatmap1 > threshold).float()
    heatmap2_binary = (heatmap2 > threshold).float()

    # Compute Jaccard Index (Intersection over Union)
    intersection = torch.sum(heatmap1_binary * heatmap2_binary)
    union = torch.sum(heatmap1_binary) + torch.sum(heatmap2_binary) - intersection

    if union == 0:
        # Both masks are empty: they agree perfectly, but 0 / 0 would give NaN.
        return torch.tensor(empty_score, dtype=union.dtype, device=union.device)
    return intersection / union


In [None]:
def calculate_jaccard_gan(dataloader, generator, device, num_images=100):
    """Average Jaccard index between generated and reference heatmaps (unconditional GAN).

    The generator is called with the image batch only; each generated heatmap is
    compared with its reference heatmap via ``jaccard_index``.

    Args:
        dataloader: iterable yielding ``(image, label, heatmap)`` batches.
        generator: model mapping an image batch to a heatmap batch.
        device: device the generator lives on.
        num_images: maximum number of samples to evaluate.

    Returns:
        Mean Jaccard index over the samples actually evaluated (at most
        ``num_images``), or ``nan`` if no sample was evaluated.
    """
    generator.eval()  # Inference mode (affects dropout / batch-norm layers).
    count = 0
    jaccard_sum = 0.0

    for image, _label, heatmap in dataloader:
        if count >= num_images:
            break

        with torch.no_grad():
            generated = generator(image.to(device)).cpu()
        heatmap = heatmap.cpu()

        for i in range(image.shape[0]):
            if count >= num_images:
                break
            jaccard_sum += jaccard_index(generated[i], heatmap[i]).item()
            count += 1

    # Divide by the number of samples actually scored: the loader may hold fewer
    # than `num_images` samples.
    return jaccard_sum / count if count else float("nan")


In [None]:
def calculate_jaccard_cgan(dataloader, generator, device, num_images=100):
    """Average Jaccard index between generated and reference heatmaps (conditional GAN).

    The generator is called with the image batch and its labels; each generated
    heatmap is compared with its reference heatmap via ``jaccard_index``.

    Args:
        dataloader: iterable yielding ``(image, label, heatmap)`` batches.
        generator: model mapping ``(image, label)`` batches to a heatmap batch.
        device: device the generator lives on.
        num_images: maximum number of samples to evaluate.

    Returns:
        Mean Jaccard index over the samples actually evaluated (at most
        ``num_images``), or ``nan`` if no sample was evaluated.
    """
    generator.eval()  # Inference mode (affects dropout / batch-norm layers).
    count = 0
    jaccard_sum = 0.0

    for image, label, heatmap in dataloader:
        if count >= num_images:
            break

        with torch.no_grad():
            # Conditional generator: the class label is a second input.
            generated = generator(image.to(device), label.to(device)).cpu()
        heatmap = heatmap.cpu()

        for i in range(image.shape[0]):
            if count >= num_images:
                break
            jaccard_sum += jaccard_index(generated[i], heatmap[i]).item()
            count += 1

    # Divide by the number of samples actually scored: the loader may hold fewer
    # than `num_images` samples.
    return jaccard_sum / count if count else float("nan")


In [None]:
def calculate_mse_cgan(dataloader, generator, device, num_images=100):
    """Average per-sample MSE between generated and reference heatmaps (conditional GAN).

    The generator is called with the image batch and its labels.

    Args:
        dataloader: iterable yielding ``(image, label, heatmap)`` batches.
        generator: model mapping ``(image, label)`` batches to a heatmap batch.
        device: device the generator lives on.
        num_images: maximum number of samples to evaluate.

    Returns:
        Mean MSE over the samples actually evaluated (at most ``num_images``),
        or ``nan`` if no sample was evaluated.
    """
    generator.eval()  # Inference mode (affects dropout / batch-norm layers).
    count = 0
    mse_sum = 0.0

    for image, label, heatmap in dataloader:
        if count >= num_images:
            break

        with torch.no_grad():
            # Conditional generator: the class label is a second input.
            generated = generator(image.to(device), label.to(device)).cpu()
        heatmap = heatmap.cpu()

        for i in range(image.shape[0]):
            if count >= num_images:
                break
            # Equivalent to nn.MSELoss() with its default "mean" reduction.
            mse_sum += torch.nn.functional.mse_loss(generated[i], heatmap[i]).item()
            count += 1

    # Divide by the number of samples actually scored: the loader may hold fewer
    # than `num_images` samples.
    return mse_sum / count if count else float("nan")

In [None]:
def calculate_mse_gan(dataloader, generator, device, num_images=100):
    """Average per-sample MSE between generated and reference heatmaps (unconditional GAN).

    The generator is called with the image batch only.

    Args:
        dataloader: iterable yielding ``(image, label, heatmap)`` batches.
        generator: model mapping an image batch to a heatmap batch.
        device: device the generator lives on.
        num_images: maximum number of samples to evaluate.

    Returns:
        Mean MSE over the samples actually evaluated (at most ``num_images``),
        or ``nan`` if no sample was evaluated.
    """
    generator.eval()  # Inference mode (affects dropout / batch-norm layers).
    count = 0
    mse_sum = 0.0

    for image, _label, heatmap in dataloader:
        if count >= num_images:
            break

        with torch.no_grad():
            generated = generator(image.to(device)).cpu()
        heatmap = heatmap.cpu()

        for i in range(image.shape[0]):
            if count >= num_images:
                break
            # Equivalent to nn.MSELoss() with its default "mean" reduction.
            mse_sum += torch.nn.functional.mse_loss(generated[i], heatmap[i]).item()
            count += 1

    # Divide by the number of samples actually scored: the loader may hold fewer
    # than `num_images` samples.
    return mse_sum / count if count else float("nan")

### Load models from Google Drive

In [None]:
import os

In [None]:
# List the saved generator checkpoints in the models folder on the mounted Google Drive.
!cd "/content/drive/My Drive/bachelorarbeit/models/"; ls


generator_cgan_epoch_100.pth	     generator_gan_v1_epoch_100.pth
generator_cgan_epoch_20.pth	     generator_gan_v1_epoch_20.pth
generator_cgan_epoch_40.pth	     generator_gan_v1_epoch_40.pth
generator_cgan_epoch_60.pth	     generator_gan_v1_epoch_60.pth
generator_cgan_epoch_80.pth	     generator_gan_v1_epoch_80.pth
generator_cgan_loss-0_epoch_100.pth  generator_gan_v2_epoch_100.pth
generator_cgan_loss-0_epoch_20.pth   generator_gan_v2_epoch_20.pth
generator_cgan_loss-0_epoch_40.pth   generator_gan_v2_epoch_40.pth
generator_cgan_loss-0_epoch_60.pth   generator_gan_v2_epoch_60.pth
generator_cgan_loss-0_epoch_80.pth   generator_gan_v2_epoch_80.pth


In [None]:
# Google Drive folder (Colab mount) holding the saved generator checkpoints.
# File names are appended by plain string concatenation, so keep the trailing slash.
base_path = "/content/drive/My Drive/bachelorarbeit/models/"

In [None]:
# Unconditional GAN checkpoints (v1 and v2), one file per saved epoch: 20, 40, ..., 100.
gan_v1_files = [f"generator_gan_v1_epoch_{epoch}.pth" for epoch in range(20, 101, 20)]
gan_v2_files = [f"generator_gan_v2_epoch_{epoch}.pth" for epoch in range(20, 101, 20)]

In [None]:
# Conditional GAN checkpoints, one file per saved epoch: 20, 40, ..., 100.
# "v1" is the run saved with the "loss-0" prefix, "v2" the run without it.
cgan_v1_files = [f"generator_cgan_loss-0_epoch_{epoch}.pth" for epoch in range(20, 101, 20)]
cgan_v2_files = [f"generator_cgan_epoch_{epoch}.pth" for epoch in range(20, 101, 20)]

In [None]:
def _load_generators(file_names, generator_cls):
    """Build one ``generator_cls`` per checkpoint in ``file_names`` (relative to ``base_path``) and load its weights."""
    generators = []
    for file_name in file_names:
        model = generator_cls().to(device)
        # map_location lets checkpoints saved on a GPU load on a CPU-only runtime as well.
        # NOTE: torch.load unpickles the file, so only load trusted checkpoints.
        state_dict = torch.load(base_path + file_name, map_location=device)
        model.load_state_dict(state_dict)
        generators.append(model)
    return generators


# Unconditional GANs use `Generator`; conditional GANs use `Generator_v2`.
gan_v1 = _load_generators(gan_v1_files, Generator)
gan_v2 = _load_generators(gan_v2_files, Generator)
cgan_v1 = _load_generators(cgan_v1_files, Generator_v2)
cgan_v2 = _load_generators(cgan_v2_files, Generator_v2)


In [None]:
# Ordered loader over `test_dataset`; used below for the generator_v3 visualization.
# NOTE(review): the metric cells below score on `dataset`, not on this test loader — confirm which split is intended.
your_dataloader = DataLoader(test_dataset, batch_size=32, shuffle=False)

Vergleich von Jaccard-Index und MSE der verschiedenen Generator-Checkpoints mit jeweils 100 Bildern

In [None]:
# Score each GAN v1 checkpoint with both metrics (unconditional: the generator sees only the image).
# NOTE(review): `dataset` is also what the v3 training loader iterates over; if it is the
# training split these scores are in-sample — consider `test_dataset` instead.
dl = DataLoader(dataset, batch_size = 32, shuffle=False)  # order is fixed, so one loader serves every checkpoint
jaccard_gan_v1, mse_loss_gan_v1 = [], []
for generator in gan_v1:
    jaccard_gan_v1.append(calculate_jaccard_gan(dl, generator, device))
    mse_loss_gan_v1.append(calculate_mse_gan(dl, generator, device))
print(jaccard_gan_v1)
print(mse_loss_gan_v1)

In [None]:
# Score each GAN v2 checkpoint with both metrics (unconditional: the generator sees only the image).
# NOTE(review): `dataset` is also what the v3 training loader iterates over; if it is the
# training split these scores are in-sample — consider `test_dataset` instead.
dl = DataLoader(dataset, batch_size = 32, shuffle=False)  # order is fixed, so one loader serves every checkpoint
jaccard_gan_v2, mse_loss_gan_v2 = [], []
for generator in gan_v2:
    jaccard_gan_v2.append(calculate_jaccard_gan(dl, generator, device))
    mse_loss_gan_v2.append(calculate_mse_gan(dl, generator, device))
print(jaccard_gan_v2)
print(mse_loss_gan_v2)

In [None]:
# Score each CGAN v1 checkpoint with both metrics (conditional: the generator also receives the label).
# NOTE(review): `dataset` is also what the v3 training loader iterates over; if it is the
# training split these scores are in-sample — consider `test_dataset` instead.
dl = DataLoader(dataset, batch_size = 32, shuffle=False)  # order is fixed, so one loader serves every checkpoint
jaccard_cgan_v1, mse_loss_cgan_v1 = [], []
for generator in cgan_v1:
    jaccard_cgan_v1.append(calculate_jaccard_cgan(dl, generator, device))
    mse_loss_cgan_v1.append(calculate_mse_cgan(dl, generator, device))
print(jaccard_cgan_v1)
print(mse_loss_cgan_v1)

In [None]:
# Score each CGAN v2 checkpoint with both metrics (conditional: the generator also receives the label).
# NOTE(review): `dataset` is also what the v3 training loader iterates over; if it is the
# training split these scores are in-sample — consider `test_dataset` instead.
dl = DataLoader(dataset, batch_size = 32, shuffle=False)  # order is fixed, so one loader serves every checkpoint
jaccard_cgan_v2 = []
mse_loss_cgan_v2 = []
for generator in cgan_v2:
    jaccard_cgan_v2.append(calculate_jaccard_cgan(dl, generator, device))
    mse_loss_cgan_v2.append(calculate_mse_cgan(dl, generator, device))
print(jaccard_cgan_v2)
print(mse_loss_cgan_v2)

In [None]:
epoch_label = ['20', '40', '60', '80', '100']

# (legend label, per-epoch values, marker, line style) for each model family.
jaccard_series = [
    ('GAN V1', jaccard_gan_v1, 'o', '-'),
    ('GAN V2', jaccard_gan_v2, 'x', '--'),
    ('CGAN V1', jaccard_cgan_v1, 's', '-.'),
    ('CGAN V2', jaccard_cgan_v2, 'd', ':'),
]

fig, ax = plt.subplots(figsize=(10, 6))
for series_name, values, marker, style in jaccard_series:
    ax.plot(epoch_label, values, marker=marker, linestyle=style, label=series_name)

ax.set_xlabel('Epochs')
ax.set_ylabel('Jaccard - Index')
ax.set_title('Jaccard - Index for Different GAN and CGAN Versions at Various Epochs')
ax.grid(True)
ax.legend()

plt.show()


In [None]:
epoch_label = ['20', '40', '60', '80', '100']

# (legend label, per-epoch values, marker, line style) for each model family.
mse_series = [
    ('GAN V1', mse_loss_gan_v1, 'o', '-'),
    ('GAN V2', mse_loss_gan_v2, 'x', '--'),
    ('CGAN V1', mse_loss_cgan_v1, 's', '-.'),
    ('CGAN V2', mse_loss_cgan_v2, 'd', ':'),
]

fig, ax = plt.subplots(figsize=(10, 6))
for series_name, values, marker, style in mse_series:
    ax.plot(epoch_label, values, marker=marker, linestyle=style, label=series_name)

ax.set_xlabel('Epochs')
ax.set_ylabel('MSE - Loss')
ax.set_title('MSE - Loss for Different GAN and CGAN Versions at Various Epochs')
ax.grid(True)
ax.legend()

plt.show()



Vergleich der 4 besten Methoden auf jeweils 100 Bildern

Visualisierung

In [None]:
# Show 5 example triplets (input / generated / reference heatmap) for each generator in `gens`.
# NOTE(review): `gens` is not defined anywhere in this part of the notebook (possibly the
# source of the NameError output below) — define it before running this cell.
# NOTE(review): visualize_results calls the generator with the image only, so conditional
# (CGAN) generators, which also take the label, would fail here.
for i, gen_model in enumerate(gens):
    print("Visualizing for one of the generators...", i + 1)
    visualize_results(DataLoader(dataset, batch_size=32, shuffle=False), gen_model, device, num_images=5)

In [None]:
# Show 5 example triplets from the `test_dataset` loader for generator_v3 (trained above with train_v3).
visualize_results(your_dataloader, generator_v3, device, num_images=5)

NameError: ignored