In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import time
import wandb
import datetime as dt
import numpy as np
import scipy.stats as stats

# Initialize Weights & Biases
wandb.init(project="gan-noise-investigation-2")

# Configuration for W&B
config = wandb.config
config.latent_size = 64
config.hidden_size = 256
config.image_size = 784
config.batch_size = 50
config.learning_rate = 0.0002
config.num_epochs = 200
config.noise_mean = 0.0   # Mean for normal and lognormal distribution
config.noise_std = 1.0    # Standard deviation for normal and lognormal
config.noise_min = -1.0   # Min for uniform distribution
config.noise_max = 1.0    # Max for uniform distribution
config.noise_lambda = 1.0 # Lambda for exponential and Poisson distribution
config.noise_alpha = 2.0  # Alpha for gamma distribution
config.noise_beta = 1.0   # Beta for gamma distribution
config.noise_type = 'gamma' #'lognormal' #'exponential' # 'uniform' #'normal'  # Default noise type


# Check for GPU and set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Data transformations
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,), std=(0.5,))
])

# Load MNIST dataset
mnist = torchvision.datasets.MNIST(root='./data', train=True, transform=transform, download=True)
data_loader = torch.utils.data.DataLoader(dataset=mnist, batch_size=config.batch_size, shuffle=True)

# Define Generator
class Generator(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(Generator, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.fc3 = nn.Linear(hidden_size, output_size)
        self.fc4 = nn.Linear(hidden_size, input_size)
        self.relu = nn.ReLU()
        self.tanh = nn.Tanh()

    def forward(self, x):
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        # x = self.relu(self.fc2(x))
        x = self.relu(self.fc2(x))
        x = self.tanh(self.fc3(x))
        return x

# Define Discriminator
class Discriminator(nn.Module):
    def __init__(self, input_size, hidden_size):
        super(Discriminator, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.fc3 = nn.Linear(hidden_size, 1)
        self.fc4 = nn.Linear(hidden_size, input_size)
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        # x = self.relu(self.fc2(x))
        # x = self.relu(self.fc2(x))
        x = self.sigmoid(self.fc3(x))
        return x

# Instantiate models and move to device
G = Generator(config.latent_size, config.hidden_size, config.image_size).to(device)
D = Discriminator(config.image_size, config.hidden_size).to(device)

# Loss and optimizers
criterion = nn.BCELoss()
d_optimizer = optim.Adam(D.parameters(), lr=config.learning_rate)
g_optimizer = optim.Adam(G.parameters(), lr=config.learning_rate)


# Monobit Frequency Test: Check if the number of 1's and 0's are approximately equal.
def monobit_frequency_test(noise):
    n = len(noise)
    count_ones = np.sum(noise)
    S = abs(count_ones - (n - count_ones))
    p_value = stats.norm.cdf(S / np.sqrt(n))  # Normal distribution approximation
    return p_value

# Block Frequency Test: Check if blocks of the sequence have an equal number of 1's and 0's.
def block_frequency_test(noise, block_size=128):
    """
    Block Frequency Test: Check if blocks of the sequence have an equal number of 1's and 0's.

    Parameters:
        noise (torch.Tensor): Binary noise sequence.
        block_size (int): Size of each block to test.

    Returns:
        p_value (float): p-value of the test.
    """
    # Calculate the number of blocks
    num_blocks = len(noise) // block_size
    
    # Calculate the sum of 1's in each block
    block_sums = [np.sum(noise[i * block_size: (i + 1) * block_size]) for i in range(num_blocks)]
    
    # Convert block_sums to a NumPy array to allow element-wise operations
    block_sums = np.array(block_sums)
    
    # Perform the chi-squared test
    chi_squared = 4 * block_size * np.sum(((block_sums - block_size / 2) / block_size) ** 2)
    
    # Calculate the p-value from the chi-squared statistic
    p_value = stats.chi2.sf(chi_squared, num_blocks)  # Survival function
    
    return p_value


# Runs Test: Tests the randomness of a sequence by examining the number of runs.
def runs_test(noise):
    n = len(noise)
    count_ones = np.sum(noise)
    count_zeros = n - count_ones
    runs = 1 + np.sum(noise[1:] != noise[:-1])
    expected_runs = 2 * count_ones * count_zeros / n + 1
    variance_runs = 2 * count_ones * count_zeros * (2 * count_ones * count_zeros - n) / (n ** 2 * (n - 1))
    z = abs(runs - expected_runs) / np.sqrt(variance_runs)
    p_value = 2 * stats.norm.cdf(-z)  # Two-tailed test
    return p_value

# Longest Runs of Ones in a Block Test
def longest_runs_of_ones_test(noise, block_size=128):
    num_blocks = len(noise) // block_size
    longest_runs = [np.max(np.diff(np.where(np.concatenate(([0], noise[i * block_size:(i + 1) * block_size], [0])) == 0))) 
                    for i in range(num_blocks)]
    p_value = stats.chi2.sf(np.sum(longest_runs), num_blocks)
    return p_value

# Add more tests as necessary...

def noise_metrics(noise):
    """
    Calculate and log metrics for the generated noise.
    
    Metrics calculated:
        - Mean
        - Standard Deviation
        - Skewness
        - Kurtosis
        - Entropy
        - Range
        - Randomness Tests: Monobit Frequency, Block Frequency, Runs Test, Longest Runs of Ones
    
    Parameters:
        noise (torch.Tensor): The noise tensor to analyze.
    """
    noise_np = noise.cpu().numpy()

    # Statistical metrics
    mean = np.mean(noise_np)
    std = np.std(noise_np)
    skewness = np.mean((noise_np - mean) ** 3) / std**3
    kurtosis = np.mean((noise_np - mean) ** 4) / std**4 - 3
    noise_range = np.max(noise_np) - np.min(noise_np)
    
    hist, _ = np.histogram(noise_np, bins=100, density=True)
    hist = hist[hist > 0]
    entropy = -np.sum(hist * np.log(hist))

    # Convert noise to binary for randomness tests (e.g., threshold at 0.5)
    binary_noise = (noise_np > 0.5).astype(int)

    # Randomness tests
    p_monobit = monobit_frequency_test(binary_noise)
    p_block = block_frequency_test(binary_noise)
    p_runs = runs_test(binary_noise)
    p_longest_runs = longest_runs_of_ones_test(binary_noise)

    # Log metrics to Weights & Biases
    wandb.log({
        "Noise Mean": mean,
        "Noise Std": std,
        "Noise Skewness": skewness,
        "Noise Kurtosis": kurtosis,
        "Noise Range": noise_range,
        "Noise Entropy": entropy,
        "Monobit Frequency Test": p_monobit,
        "Block Frequency Test": p_block,
        "Runs Test": p_runs,
        "Longest Runs of Ones Test": p_longest_runs
    })


def generate_noise(batch_size, latent_size, noise_type='normal'):
    """
    Generates noise based on different distributions and logs metrics for the noise.

    Parameters:
        batch_size (int): The number of noise vectors to generate.
        latent_size (int): The size of each noise vector.
        noise_type (str): The type of distribution to sample from.

    Returns:
        torch.Tensor: The generated noise vector.
    """
    
    if noise_type == 'normal':
        z = torch.randn(batch_size, latent_size).to(device) * config.noise_std + config.noise_mean

    elif noise_type == 'uniform':
        z = torch.rand(batch_size, latent_size).to(device) * (config.noise_max - config.noise_min) + config.noise_min

    elif noise_type == 'exponential':
        z = torch.distributions.Exponential(config.noise_lambda).sample((batch_size, latent_size)).to(device)

    elif noise_type == 'lognormal':
        z = torch.distributions.LogNormal(config.noise_mean, config.noise_std).sample((batch_size, latent_size)).to(device)

    elif noise_type == 'gamma':
        z = torch.distributions.Gamma(config.noise_alpha, config.noise_beta).sample((batch_size, latent_size)).to(device)

    elif noise_type == 'poisson':
        z = torch.poisson(torch.full((batch_size, latent_size), config.noise_lambda)).to(device)

    elif noise_type == 'random_binary':
        z = torch.randint(0, 2, (batch_size, latent_size)).float().to(device)  # Binary random noise

    else:
        raise ValueError(f"Unsupported noise type: {noise_type}")

    # # Log noise distribution type and parameters
    # wandb.log({
    #     "Noise Type": noise_type,
    #     "Noise Mean": config.noise_mean if noise_type in ['normal', 'lognormal'] else -100,
    #     "Noise Std": config.noise_std if noise_type in ['normal', 'lognormal'] else -100,
    #     "Noise Lambda": config.noise_lambda if noise_type in ['exponential', 'poisson'] else -100,
    #     "Noise Alpha": config.noise_alpha if noise_type == 'gamma' else -100,
    #     "Noise Beta": config.noise_beta if noise_type == 'gamma' else -100,
    #     "Noise Min": config.noise_min if noise_type == 'uniform' else -100,
    #     "Noise Max": config.noise_max if noise_type == 'uniform' else -100
    # })

    # Calculate and log noise metrics
    # noise_metrics(z)

    return z

# Define the stopping criterion (number of consecutive epochs)
stop_epochs_threshold = 5  # Stop after this many consecutive epochs where the condition is met

# Initialize the counter for consecutive successful epochs
consecutive_epochs = 0


# Training loop
start_time = time.time()

for epoch in range(config.num_epochs):
    start_time = time.time()

    for i, (images, _) in enumerate(data_loader):
        images = images.reshape(config.batch_size, -1).to(device)
        real_labels = torch.ones(config.batch_size, 1).to(device)
        fake_labels = torch.zeros(config.batch_size, 1).to(device)

        # Train Discriminator
        outputs = D(images)
        d_loss_real = criterion(outputs, real_labels)
        real_score = outputs

        # Generate noise with tunable parameters
        z = generate_noise(config.batch_size, config.latent_size, config.noise_type)
        fake_images = G(z)
        outputs = D(fake_images)
        d_loss_fake = criterion(outputs, fake_labels)
        fake_score = outputs

        d_loss = d_loss_real + d_loss_fake
        d_optimizer.zero_grad()
        d_loss.backward()
        d_optimizer.step()

        # Train Generator
        z = generate_noise(config.batch_size, config.latent_size, config.noise_type)
        fake_images = G(z)
        outputs = D(fake_images)
        g_loss = criterion(outputs, real_labels)
        g_optimizer.zero_grad()
        g_loss.backward()
        g_optimizer.step()

        # Print and log progress
        if (i + 1) % (len(data_loader) // 4) == 0:
            elapsed_time = time.time() - start_time
            print(f'Epoch [{epoch}/{config.num_epochs}], Step [{i + 1}/{len(data_loader)}], '
                  f'D Loss: {d_loss.item():.4f}, G Loss: {g_loss.item():.4f}, '
                  f'D(x): {real_score.mean().item():.2f}, D(G(z)): {fake_score.mean().item():.2f}, '
                  f'Time Elapsed: {elapsed_time:.2f} sec')

            noise_metrics(z)


            # Log metrics to W&B
            wandb.log({
                'Epoch': epoch,
                'D Loss': d_loss.item(),
                'G Loss': g_loss.item(),
                'D(x)': real_score.mean().item(),
                'D(G(z))': fake_score.mean().item(),
                'Time Elapsed': elapsed_time
            })
    
    # Stopping criterion
    if real_score.mean().item() >= 0.98 and fake_score.mean().item() <= 0.02:
        consecutive_epochs += 1
    else:
        consecutive_epochs = 0

    # Check if the stopping criterion is met
    if consecutive_epochs >= stop_epochs_threshold:
        print(f"Stopping training as D(x) ≈ 1.00 and D(G(z)) ≈ 0.00 for {stop_epochs_threshold} consecutive epochs.")
        break

    # Save and visualize generated images
    if (epoch + 1) % 20 == 0:
        with torch.no_grad():
            fake_images = fake_images.reshape(fake_images.size(0), 1, 28, 28)
            grid = torchvision.utils.make_grid(fake_images, nrow=10, normalize=True)
            plt.imshow(grid.permute(1, 2, 0).cpu().numpy())
            plt.show()

# Save models
timestamp = dt.datetime.now().strftime("%y%m%d%H%M%S")
torch.save(G.state_dict(), f'models/generator_{timestamp}.pth')
torch.save(D.state_dict(), f'models/discriminator_{timestamp}.pth')

# End W&B run
wandb.finish()


Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mbobbypestana[0m ([33mbobbypestana-kvantify[0m). Use [1m`wandb login --relogin`[0m to force relogin


Epoch [0/200], Step [300/1200], D Loss: 0.3354, G Loss: 2.3762, D(x): 0.94, D(G(z)): 0.23, Time Elapsed: 5.01 sec
Epoch [0/200], Step [600/1200], D Loss: 0.1203, G Loss: 5.3818, D(x): 0.96, D(G(z)): 0.06, Time Elapsed: 8.62 sec
Epoch [0/200], Step [900/1200], D Loss: 0.1545, G Loss: 3.5579, D(x): 0.98, D(G(z)): 0.12, Time Elapsed: 12.00 sec
Epoch [0/200], Step [1200/1200], D Loss: 0.0454, G Loss: 3.9309, D(x): 0.99, D(G(z)): 0.04, Time Elapsed: 15.66 sec
Epoch [1/200], Step [300/1200], D Loss: 0.0210, G Loss: 5.1269, D(x): 0.99, D(G(z)): 0.01, Time Elapsed: 3.73 sec
Epoch [1/200], Step [600/1200], D Loss: 0.0813, G Loss: 6.6114, D(x): 0.95, D(G(z)): 0.03, Time Elapsed: 7.14 sec
Epoch [1/200], Step [900/1200], D Loss: 0.0868, G Loss: 6.7348, D(x): 0.95, D(G(z)): 0.02, Time Elapsed: 10.51 sec
Epoch [1/200], Step [1200/1200], D Loss: 0.0284, G Loss: 4.9750, D(x): 0.99, D(G(z)): 0.02, Time Elapsed: 13.68 sec
Epoch [2/200], Step [300/1200], D Loss: 0.4066, G Loss: 6.9971, D(x): 0.82, D(G(z)

VBox(children=(Label(value='0.003 MB of 0.003 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
D Loss,▅▂▃▁▁▂▂▁▅█▂▁▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
D(G(z)),▆▂▄▂▁▂▂▁▂█▁▁▂▁▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
D(x),▆▆▇█▇▆▆█▁▃▆███▇█████████████████████
Epoch,▁▁▁▁▂▂▂▂▃▃▃▃▄▄▄▄▅▅▅▅▅▅▅▅▆▆▆▆▇▇▇▇████
G Loss,▁▃▂▂▃▄▄▃▄▃▃▄▃▃▃▃▄▄▃▅▃▃▆▄▄▄▆▇▇▇▇█████
Monobit Frequency Test,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
Noise Entropy,▆▅█▃▅▆▅▇▆▅▅▅▅▁▆▆▅▅▃▇▅▆▆▅▄▅█▆▅▅▅▅▅▇▅▄
Noise Kurtosis,▂▃▂▃▃▂▃▁▂▃▃▃▃█▃▂▂▃▄▂▃▃▂▂▃▂▁▃▄▂▂▂▃▂▂▄
Noise Mean,█▃▄▆▄▅▁▆█▄▁▅▇█▆▃▄▇▄▂▄▅▅▄▄▅▆▂▄▆▆▄▃▃▃▅
Noise Range,▂▃▁▄▃▂▂▁▂▃▃▃▃█▂▂▃▃▄▁▂▂▂▃▃▃▁▂▃▃▃▂▃▂▃▄

0,1
Block Frequency Test,
D Loss,0.0
D(G(z)),0.0
D(x),1.0
Epoch,8.0
G Loss,13.1773
Longest Runs of Ones Test,
Monobit Frequency Test,1.0
Noise Entropy,12.61408
Noise Kurtosis,3.63846


: 