# InfoGAN

## classes

- Generator generates different images based on the random class and continuous features.
- QHead classifies the classes and performs the regression on the continuous features.
- Generator is trained to help QHead solve its problem better, meaning the content and styles
    from class and features should be distinguishable.
- Discriminator also discriminates between real and fake images.


In [48]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

class Generator(nn.Module):
    """Fully connected generator: noise plus latent codes in, flat vector out.

    ``forward`` concatenates its three inputs along dim 1, so ``input_size``
    must equal the sum of their widths. The final ``Tanh`` bounds every
    output element to [-1, 1].
    """

    def __init__(self, input_size, hidden_dim, output_size):
        super().__init__()
        layers = [
            nn.Linear(input_size, hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, output_size),
            nn.Tanh(),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, z, classes, continuous):
        """Concatenate noise, class code and continuous codes, then run the MLP."""
        latent = torch.cat((z, classes, continuous), dim=1)
        return self.net(latent)

class Discriminator(nn.Module):
    """Fully connected discriminator that also exposes its hidden features.

    ``forward`` returns a pair: the sigmoid score from ``classifier`` and the
    activations of ``feature_extractor`` that the score was computed from.
    """

    def __init__(self, input_size, hidden_dim, output_dim=1):
        super().__init__()
        body = [
            nn.Linear(input_size, hidden_dim),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(hidden_dim, hidden_dim),
            nn.LeakyReLU(0.2, inplace=True),
        ]
        head = [
            nn.Linear(hidden_dim, output_dim),
            nn.Sigmoid(),
        ]
        self.feature_extractor = nn.Sequential(*body)
        self.classifier = nn.Sequential(*head)

    def forward(self, x):
        """Return ``(score, features)`` for the input batch ``x``."""
        hidden = self.feature_extractor(x)
        return self.classifier(hidden), hidden

class QHead(nn.Module):
    """Auxiliary head that predicts the latent codes from input features.

    Args:
        input_dim: Width of the feature vector fed to the head.
        num_categorical: Number of categorical classes to predict. When it is
            0 the categorical output of ``forward`` is ``None``.
        num_continuous: Number of continuous codes to regress.
        hidden_dim: Width of the shared hidden layer. Defaults to 128, the
            value that was previously hard-coded.

    ``forward`` returns ``(categorical, mu, var)``:
        categorical: Class *probabilities* (softmax already applied), or
            ``None``. Because these are not logits, they must not be passed
            directly to a loss that applies its own log-softmax, such as
            ``nn.CrossEntropyLoss``.
        mu: Predicted mean of each continuous code.
        var: Predicted variance of each continuous code (strictly positive).
    """

    def __init__(self, input_dim, num_categorical, num_continuous, hidden_dim=128):
        super().__init__()
        self.num_categorical = num_categorical
        self.num_continuous = num_continuous

        self.fc = nn.Linear(input_dim, hidden_dim)
        self.fc_cat = nn.Linear(hidden_dim, num_categorical)  # categorical codes
        self.fc_cont_mu = nn.Linear(hidden_dim, num_continuous)  # continuous codes: mean
        self.fc_cont_var = nn.Linear(hidden_dim, num_continuous)  # continuous codes: log-variance

    def forward(self, x):
        """Return ``(categorical, mu, var)`` for the feature batch ``x``."""
        x = F.leaky_relu(self.fc(x), 0.1)
        categorical = None
        if self.num_categorical > 0:
            categorical = F.softmax(self.fc_cat(x), dim=1)
        mu = self.fc_cont_mu(x)
        # The layer output is treated as a log-variance; exponentiating
        # guarantees a positive variance.
        var = torch.exp(self.fc_cont_var(x))
        return categorical, mu, var

# --- Hyperparameters ---------------------------------------------------------
z_dim = 100  # width of the noise vector
c_dim = 10  # number of categorical classes
g_hidden = 128
d_hidden = 128
img_dim = 784  # flattened image size (28 * 28)
num_continuouss = 24  # number of continuous latent codes

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# --- Models and optimizers ---------------------------------------------------
# The class is fed to the generator as a single scalar column (not one-hot),
# hence the "+ 1" in the generator input width.
netG = Generator(z_dim + num_continuouss + 1, g_hidden, img_dim).to(device)
netD = Discriminator(img_dim, d_hidden).to(device)
# QHead consumes the discriminator's hidden features, so its input is d_hidden.
netQ = QHead(d_hidden, c_dim, num_continuouss).to(device)

optimizerD = optim.Adam(netD.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizerG = optim.Adam(netG.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizerQ = optim.Adam(netQ.parameters(), lr=0.0002, betas=(0.5, 0.999))

num_epochs = 1

n_batch = 32

# Random stand-in for one batch of flattened images.
image = torch.randn(n_batch, img_dim, device=device)

criterionD = nn.BCELoss()
# QHead already applies softmax, so the categorical loss must be NLL on the
# log-probabilities. CrossEntropyLoss would apply log-softmax a second time.
criterionQ = nn.NLLLoss()

# --- Training loop -----------------------------------------------------------
for epoch in range(num_epochs):
    for real_images in [image]:
        real_images = real_images.to(device)
        batch_size = real_images.size(0)
        real_labels = torch.ones(batch_size, 1, device=device)
        fake_labels = torch.zeros(batch_size, 1, device=device)

        # Sample the latent inputs once per step; size them from the actual
        # batch and create them on the same device as the models.
        continuous_target = torch.randn(batch_size, num_continuouss, device=device)
        noise = torch.randn(batch_size, z_dim, device=device)
        random_classes = torch.randint(0, c_dim, (batch_size, 1), device=device)
        # Explicit float copy for concatenation with the float noise/codes;
        # the integer tensor is kept as the classification target.
        class_input = random_classes.float()

        # ---- Discriminator step: real -> 1, fake -> 0 ----
        netD.zero_grad()
        real_output, real_features = netD(real_images)
        errD_real = criterionD(real_output, real_labels)
        errD_real.backward()

        fake_images = netG(noise, class_input, continuous_target)
        # detach() keeps the discriminator loss from back-propagating into G.
        fake_output, fake_features = netD(fake_images.detach())
        errD_fake = criterionD(fake_output, fake_labels)
        errD_fake.backward()
        errD = errD_real + errD_fake
        optimizerD.step()

        # ---- Generator + QHead step ----
        netG.zero_grad()
        netQ.zero_grad()

        # One forward pass through the updated discriminator yields both the
        # adversarial score and the features QHead needs.
        fake_images = netG(noise, class_input, continuous_target)
        output, fake_features = netD(fake_images)
        errG = criterionD(output, real_labels)

        categorical_output, mu, var = netQ(fake_features)
        # Small epsilon guards against log(0) for saturated probabilities.
        # squeeze(1) keeps a 1-D target even when batch_size == 1.
        errQ = criterionQ(torch.log(categorical_output + 1e-8), random_classes.squeeze(1))
        # Gaussian negative log-likelihood of the sampled codes (constant
        # term dropped): summed over codes, averaged over the batch so its
        # scale matches the batch-mean categorical and adversarial losses.
        gaussian_loss = (
            0.5 * torch.log(var) + 0.5 * ((continuous_target - mu) ** 2) / var
        ).sum(dim=1).mean()
        errQ = errQ + gaussian_loss

        # A single backward pass gives G the gradients of both losses and
        # QHead the gradients of errQ. Gradients that land on netD here are
        # cleared by netD.zero_grad() at the start of the next step.
        (errG + errQ).backward()
        optimizerG.step()
        optimizerQ.step()


In [60]:
num_classes = 1  # width of the class input: the class is passed as one scalar column
num_continuous = 24
z_dim = 100  # Size of the noise vector

# Build every input on the generator's own device so the sweep also works
# when the model lives on the GPU.
gen_device = next(netG.parameters()).device

# Fix the categorical class (scalar value 1.0)
class_vector = torch.ones(1, num_classes, device=gen_device)

# Prepare a fixed value for most continuous features.
# It must be a float: an integer fill value makes torch.full return an integer
# tensor, and the fractional sweep values assigned below would be truncated.
fixed_value = 0.0  # Assuming features are normally distributed around 0
fixed_continuous = torch.full(
    (1, num_continuous), fixed_value, dtype=torch.float, device=gen_device
)

# Choose a range of values to vary for one feature
feature_to_vary = torch.linspace(-2, 2, steps=20)

# Sample the noise once and reuse it for every step, so the only thing that
# changes between generated images is the swept continuous feature.
noise = torch.randn(1, z_dim, device=gen_device)

# Generate images while varying the selected feature
netG.eval()  # Set the generator to evaluation mode
generated_images = []
with torch.no_grad():
    for i in feature_to_vary:
        # Example: varying the first continuous feature
        continuous_vector = fixed_continuous.clone()
        continuous_vector[0, 0] = i

        generated_image = netG(noise, class_vector, continuous_vector)
        generated_images.append(generated_image)

len(generated_images)

20