In [1]:
from GAN import Generator, Discriminator
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt

from Signal_Generator import *
from Signal_Analyzer import *

In [2]:
dataset = []

for i in range(10):
    SG = Signal_Generator(num_sources=10, noise_amplitude=1)
    signals = SG.generating_signal()
    params = SG.printing_parameters()
    signal = signals['Signal'].values

    signal_tensor = torch.tensor(signal, dtype=torch.float).unsqueeze(0).unsqueeze(0)
    params_tensor = torch.tensor(params, dtype=torch.float).unsqueeze(0)

    dataset.append((signal_tensor, params_tensor))

num_latent_variables = 10
z = torch.randn(1, num_latent_variables, 1)

In [3]:
generator = Generator(in_channels=1, num_latent_variables=num_latent_variables, length=len(signal), num_parameters=len(params))
discriminator = Discriminator(input_channels=1, num_latent_variables=num_latent_variables, length=len(signal), num_parameters=len(params))

In [4]:
criterion = nn.BCELoss()

learning_rate = 0.0001
num_epochs = 10

optimizer_G = torch.optim.Adam(generator.parameters(), lr=learning_rate)
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=learning_rate)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [5]:
for epoch in range(num_epochs):
    for i,data in enumerate(dataset):
        signal_tensor, params_tensor = data
        
        # Generate Fake Parameters
        fake_params = generator(signal_tensor, z).detach().squeeze().numpy()
        fake_params_tensor = torch.tensor(fake_params, dtype=torch.float).unsqueeze(0)
        
        # Train Discriminator
        discriminator.zero_grad()

        real_output = discriminator(signal_tensor, params_tensor, z)
        fake_output = discriminator(signal_tensor, fake_params_tensor, z)

        loss_D = criterion(real_output, torch.ones_like(real_output)) + criterion(fake_output, torch.zeros_like(fake_output))
        loss_D.backward()
        optimizer_D.step()
        
        # Train Generator
        generator.zero_grad()
        fake_output = discriminator(signal_tensor, fake_params_tensor, z)
        loss_G = criterion(fake_output, torch.ones_like(fake_output))
        loss_G.backward()
        optimizer_G.step()
        
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss_D: {loss_D.item():.4f}, Loss_G: {loss_G.item():.4f}")

Epoch [1/10], Loss_D: 1.5011, Loss_G: 0.9237
Epoch [2/10], Loss_D: 1.3960, Loss_G: 0.5942
Epoch [3/10], Loss_D: 1.3806, Loss_G: 0.7149
Epoch [4/10], Loss_D: 1.3834, Loss_G: 0.7515
Epoch [5/10], Loss_D: 1.3818, Loss_G: 0.7311
Epoch [6/10], Loss_D: 1.3801, Loss_G: 0.7081
Epoch [7/10], Loss_D: 1.3795, Loss_G: 0.7071
Epoch [8/10], Loss_D: 1.3791, Loss_G: 0.7145
Epoch [9/10], Loss_D: 1.3784, Loss_G: 0.7162
Epoch [10/10], Loss_D: 1.3775, Loss_G: 0.7136


In [6]:
SG = Signal_Generator(num_sources=10, noise_amplitude=1)
signals = SG.generating_signal()
params = SG.printing_parameters()
signal = signals['Signal'].values

print(params)

signal_tensor = torch.tensor(signal, dtype=torch.float).unsqueeze(0).unsqueeze(0)
params_tensor = torch.tensor(params, dtype=torch.float).unsqueeze(0)

generate_params = generator(signal_tensor, z).detach().squeeze().numpy()
print(generate_params)

[11.932606705122195, 11.162418137025993, 10.4756841253002, 4.987917542700986, 6.680472242895795, 9.476344061090538, 11.49344133978031, 7.710859755200287, 8.150962593003921, 9.69722944345947, 0.5280247517513108, 0.29307557035745196, 0.4125361499479564, 0.23165350097477755, 0.3668298432325661, 0.19174394102041625, 0.30709077623889464, 0.2659907904225629, 0.2302718133085357, 0.18864355760945117, 2.1258778163883654, 3.1314028956899485, 5.787724128895, 1.2106760330313815, 2.1463099429743857, 5.6504466748531845, 1.952982369564885, 5.114640044810341, 1.4848609981232601, 1.6946353361577309]
[-0.14385965  0.28918138  0.13915028 -0.18603382  0.31020144  0.21562594
  0.21877581 -0.07718515  0.39901143 -0.06567101 -0.2909413   0.38289204
 -0.15082283  0.15140295 -0.2656356   0.25412095 -0.3449543  -0.01768567
  0.11298192 -0.29028493  0.06241315 -0.01468661 -0.15705107 -0.11825426
 -0.11011943 -0.5264281  -0.11834784 -0.1210531  -0.34998825  0.7186213 ]


In [7]:
len(params)

30

In [8]:
len(generate_params)

30