In [1]:
import torch

from train_gan import (
    Critic,
    EuXFELCurrentDataModule,
    EuXFELCurrentDataset,
    Generator,
    ConvolutionalDecoder,
    ConvolutionalEncoder,
)


In [2]:
formfactor_encoder = ConvolutionalEncoder(signal_dims=240, latent_dims=10)
formfactor_encoder


ConvolutionalEncoder(
  (convnet): Sequential(
    (0): Conv1d(1, 8, kernel_size=(3,), stride=(2,), padding=(1,))
    (1): LeakyReLU(negative_slope=0.01)
    (2): Conv1d(8, 16, kernel_size=(3,), stride=(2,), padding=(1,))
    (3): LeakyReLU(negative_slope=0.01)
    (4): Conv1d(16, 32, kernel_size=(3,), stride=(2,), padding=(1,))
    (5): LeakyReLU(negative_slope=0.01)
  )
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (mlp): Sequential(
    (0): Linear(in_features=960, out_features=100, bias=True)
    (1): LeakyReLU(negative_slope=0.01)
    (2): Linear(in_features=100, out_features=50, bias=True)
    (3): LeakyReLU(negative_slope=0.01)
    (4): Linear(in_features=50, out_features=10, bias=True)
  )
)

In [3]:
formfactor = torch.rand(3, 240)
encoded = formfactor_encoder(formfactor)
print(f"{encoded = }")
print(f"{encoded.requires_grad = }")


encoded = tensor([[-0.0202,  0.0085, -0.0483,  0.1294, -0.1184,  0.0755,  0.1278,  0.0924,
          0.1252,  0.0485],
        [-0.0221,  0.0060, -0.0486,  0.1302, -0.1174,  0.0758,  0.1283,  0.0915,
          0.1242,  0.0501],
        [-0.0209,  0.0079, -0.0460,  0.1289, -0.1160,  0.0720,  0.1328,  0.0942,
          0.1250,  0.0492]], grad_fn=<AddmmBackward0>)
encoded.requires_grad = True


In [4]:
current_encoder = ConvolutionalEncoder(signal_dims=300, latent_dims=10)
current_encoder


ConvolutionalEncoder(
  (convnet): Sequential(
    (0): Conv1d(1, 8, kernel_size=(3,), stride=(2,), padding=(1,))
    (1): LeakyReLU(negative_slope=0.01)
    (2): Conv1d(8, 16, kernel_size=(3,), stride=(2,), padding=(1,))
    (3): LeakyReLU(negative_slope=0.01)
    (4): Conv1d(16, 32, kernel_size=(3,), stride=(2,), padding=(1,))
    (5): LeakyReLU(negative_slope=0.01)
  )
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (mlp): Sequential(
    (0): Linear(in_features=1216, out_features=100, bias=True)
    (1): LeakyReLU(negative_slope=0.01)
    (2): Linear(in_features=100, out_features=50, bias=True)
    (3): LeakyReLU(negative_slope=0.01)
    (4): Linear(in_features=50, out_features=10, bias=True)
  )
)

In [5]:
current_profile = torch.rand(3, 300)
encoded = current_encoder(current_profile)
print(f"{encoded = }")
print(f"{encoded.requires_grad = }")


encoded = tensor([[ 0.0613,  0.1177,  0.0679, -0.0781, -0.0600, -0.1421,  0.0401,  0.0310,
         -0.0717,  0.1091],
        [ 0.0642,  0.1166,  0.0718, -0.0801, -0.0579, -0.1412,  0.0397,  0.0290,
         -0.0726,  0.1109],
        [ 0.0627,  0.1174,  0.0705, -0.0794, -0.0590, -0.1400,  0.0387,  0.0319,
         -0.0713,  0.1071]], grad_fn=<AddmmBackward0>)
encoded.requires_grad = True


In [6]:
current_decoder = ConvolutionalDecoder(latent_dims=10 + 5 + 1, signal_dims=300)
current_decoder


ConvolutionalDecoder(
  (mlp): Sequential(
    (0): Linear(in_features=16, out_features=50, bias=True)
    (1): LeakyReLU(negative_slope=0.01)
    (2): Linear(in_features=50, out_features=100, bias=True)
    (3): LeakyReLU(negative_slope=0.01)
    (4): Linear(in_features=100, out_features=1216, bias=True)
    (5): LeakyReLU(negative_slope=0.01)
  )
  (unflatten): Unflatten(dim=1, unflattened_size=(32, 38))
  (convnet): Sequential(
    (0): ConvTranspose1d(32, 16, kernel_size=(3,), stride=(2,), padding=(1,))
    (1): LeakyReLU(negative_slope=0.01)
    (2): ConvTranspose1d(16, 8, kernel_size=(3,), stride=(2,), padding=(1,), output_padding=(1,))
    (3): LeakyReLU(negative_slope=0.01)
    (4): ConvTranspose1d(8, 1, kernel_size=(3,), stride=(2,), padding=(1,), output_padding=(1,))
    (5): ReLU()
  )
)

In [7]:
current_decoder.parameters().__next__().requires_grad


True

In [8]:
latent = torch.rand(3, 10 + 5 + 1)
decoded = current_decoder(latent)
print(f"{decoded = }")
print(f"{decoded.requires_grad = }")


decoded = tensor([[0.4040, 0.3124, 0.4293, 0.2998, 0.4077, 0.3106, 0.4098, 0.3027, 0.4075,
         0.3148, 0.4050, 0.3123, 0.4012, 0.2957, 0.4189, 0.2879, 0.4110, 0.3018,
         0.4206, 0.3037, 0.4004, 0.2887, 0.4218, 0.2958, 0.4185, 0.2997, 0.4221,
         0.3146, 0.4034, 0.3138, 0.4158, 0.2905, 0.3998, 0.2834, 0.4127, 0.3031,
         0.4201, 0.2879, 0.4160, 0.3050, 0.4210, 0.2906, 0.4121, 0.3206, 0.3992,
         0.3003, 0.4255, 0.2852, 0.4010, 0.3077, 0.4220, 0.3199, 0.4241, 0.2944,
         0.4117, 0.2863, 0.4212, 0.2968, 0.4188, 0.2809, 0.4055, 0.2899, 0.4087,
         0.2975, 0.4105, 0.2957, 0.4150, 0.3075, 0.4007, 0.2993, 0.4166, 0.3020,
         0.4013, 0.2934, 0.4215, 0.2989, 0.4151, 0.2988, 0.4196, 0.3130, 0.4130,
         0.3046, 0.4178, 0.3044, 0.3997, 0.2983, 0.4170, 0.2820, 0.4122, 0.3025,
         0.4049, 0.3138, 0.4039, 0.2959, 0.4202, 0.2899, 0.4142, 0.3072, 0.4231,
         0.2966, 0.4047, 0.2887, 0.4104, 0.2863, 0.4183, 0.3091, 0.4203, 0.3005,
         0.4026, 0

In [9]:
formfactor_decoder = ConvolutionalDecoder(latent_dims=10 + 5 + 1, signal_dims=240)
formfactor_decoder


ConvolutionalDecoder(
  (mlp): Sequential(
    (0): Linear(in_features=16, out_features=50, bias=True)
    (1): LeakyReLU(negative_slope=0.01)
    (2): Linear(in_features=50, out_features=100, bias=True)
    (3): LeakyReLU(negative_slope=0.01)
    (4): Linear(in_features=100, out_features=960, bias=True)
    (5): LeakyReLU(negative_slope=0.01)
  )
  (unflatten): Unflatten(dim=1, unflattened_size=(32, 30))
  (convnet): Sequential(
    (0): ConvTranspose1d(32, 16, kernel_size=(3,), stride=(2,), padding=(1,), output_padding=(1,))
    (1): LeakyReLU(negative_slope=0.01)
    (2): ConvTranspose1d(16, 8, kernel_size=(3,), stride=(2,), padding=(1,), output_padding=(1,))
    (3): LeakyReLU(negative_slope=0.01)
    (4): ConvTranspose1d(8, 1, kernel_size=(3,), stride=(2,), padding=(1,), output_padding=(1,))
    (5): ReLU()
  )
)

In [10]:
latent = torch.rand(3, 10 + 5 + 1)
decoded = formfactor_decoder(latent)
print(f"{decoded = }")
print(f"{decoded.requires_grad = }")


decoded = tensor([[0.6500, 0.6566, 0.6798, 0.6840, 0.6584, 0.6638, 0.6894, 0.6669, 0.6447,
         0.6617, 0.6758, 0.6618, 0.6412, 0.6610, 0.6857, 0.6701, 0.6433, 0.6655,
         0.6616, 0.6546, 0.6497, 0.6661, 0.6858, 0.6751, 0.6530, 0.6641, 0.6800,
         0.6695, 0.6500, 0.6615, 0.6831, 0.6685, 0.6417, 0.6533, 0.6770, 0.6748,
         0.6447, 0.6562, 0.6867, 0.6658, 0.6411, 0.6637, 0.6749, 0.6686, 0.6442,
         0.6599, 0.6766, 0.6672, 0.6522, 0.6700, 0.6746, 0.6551, 0.6403, 0.6611,
         0.6813, 0.6686, 0.6406, 0.6532, 0.6556, 0.6616, 0.6530, 0.6700, 0.6848,
         0.6671, 0.6483, 0.6608, 0.6796, 0.6731, 0.6431, 0.6562, 0.6801, 0.6658,
         0.6559, 0.6712, 0.6696, 0.6542, 0.6465, 0.6717, 0.6847, 0.6700, 0.6547,
         0.6740, 0.6843, 0.6648, 0.6535, 0.6711, 0.6843, 0.6687, 0.6459, 0.6653,
         0.6750, 0.6567, 0.6416, 0.6573, 0.6799, 0.6757, 0.6545, 0.6678, 0.6889,
         0.6782, 0.6441, 0.6481, 0.6731, 0.6704, 0.6478, 0.6607, 0.6784, 0.6680,
         0.6560, 0

In [11]:
generator = Generator()
generator


Generator(
  (formfactor_encoder): ConvolutionalEncoder(
    (convnet): Sequential(
      (0): Conv1d(1, 8, kernel_size=(3,), stride=(2,), padding=(1,))
      (1): LeakyReLU(negative_slope=0.01)
      (2): Conv1d(8, 16, kernel_size=(3,), stride=(2,), padding=(1,))
      (3): LeakyReLU(negative_slope=0.01)
      (4): Conv1d(16, 32, kernel_size=(3,), stride=(2,), padding=(1,))
      (5): LeakyReLU(negative_slope=0.01)
    )
    (flatten): Flatten(start_dim=1, end_dim=-1)
    (mlp): Sequential(
      (0): Linear(in_features=960, out_features=100, bias=True)
      (1): LeakyReLU(negative_slope=0.01)
      (2): Linear(in_features=100, out_features=50, bias=True)
      (3): LeakyReLU(negative_slope=0.01)
      (4): Linear(in_features=50, out_features=10, bias=True)
    )
  )
  (current_decoder): ConvolutionalDecoder(
    (mlp): Sequential(
      (0): Linear(in_features=16, out_features=50, bias=True)
      (1): LeakyReLU(negative_slope=0.01)
      (2): Linear(in_features=50, out_features=100

In [12]:
generator.parameters().__next__().requires_grad


True

In [13]:
formfactor = torch.rand(3, 240)
rf_settings = torch.rand(3, 5)
bunch_length = torch.rand(3, 1)

current_profile = generator(formfactor, rf_settings, bunch_length)
print(f"{current_profile.size() = }")
print(f"{current_profile.requires_grad = }")


current_profile.size() = torch.Size([3, 300])
current_profile.requires_grad = True


In [14]:
critic = Critic()
critic


Critic(
  (formfactor_encoder): ConvolutionalEncoder(
    (convnet): Sequential(
      (0): Conv1d(1, 8, kernel_size=(3,), stride=(2,), padding=(1,))
      (1): LeakyReLU(negative_slope=0.01)
      (2): Conv1d(8, 16, kernel_size=(3,), stride=(2,), padding=(1,))
      (3): LeakyReLU(negative_slope=0.01)
      (4): Conv1d(16, 32, kernel_size=(3,), stride=(2,), padding=(1,))
      (5): LeakyReLU(negative_slope=0.01)
    )
    (flatten): Flatten(start_dim=1, end_dim=-1)
    (mlp): Sequential(
      (0): Linear(in_features=960, out_features=100, bias=True)
      (1): LeakyReLU(negative_slope=0.01)
      (2): Linear(in_features=100, out_features=50, bias=True)
      (3): LeakyReLU(negative_slope=0.01)
      (4): Linear(in_features=50, out_features=10, bias=True)
    )
  )
  (current_encoder): ConvolutionalEncoder(
    (convnet): Sequential(
      (0): Conv1d(1, 8, kernel_size=(3,), stride=(2,), padding=(1,))
      (1): LeakyReLU(negative_slope=0.01)
      (2): Conv1d(8, 16, kernel_size=(3,),

In [15]:
critique = critic(current_profile, formfactor, rf_settings, bunch_length)
print(f"{critique = }")
print(f"{critique.requires_grad = }")


critique = tensor([[-0.0524],
        [-0.0548],
        [-0.0673]], grad_fn=<AddmmBackward0>)
critique.requires_grad = True


In [16]:
dataset = EuXFELCurrentDataset(normalize=True)
dataset


<train_gan.EuXFELCurrentDataset at 0x2af237310>

In [17]:
len(dataset)


32000

In [18]:
(formfactor, rf_settings, bunch_length), current_profile = dataset[0]


In [19]:
formfactor


tensor([ 0.5686,  0.7723,  1.0779,  0.6725,  0.4086,  0.7040,  0.5354,  0.4655,
         0.3015,  0.7461,  0.6887,  0.3759,  0.4386,  0.4940,  0.5174,  0.7427,
         0.5371,  0.4569,  0.4125,  0.7526,  0.5118,  0.5952,  0.3537,  0.6484,
         0.6500,  0.5328,  0.5734,  0.4914,  0.6267,  0.4590,  0.4502,  0.7758,
         0.7688,  0.6264,  0.5553,  0.6135,  0.6125,  0.5770,  0.5955,  0.5749,
         0.5826,  0.6356,  0.6191,  0.5773,  0.5847,  0.6125,  0.5815,  0.5982,
         0.5877,  0.5858,  0.6251,  0.5709,  0.5883,  0.5329,  0.5751,  0.5578,
         0.5513,  0.5446,  0.5241,  0.4902,  0.5763,  0.5922,  0.4785,  0.5050,
         0.3155,  0.4155,  0.3747,  0.3794,  0.3606,  0.3266,  0.2784,  0.3489,
         0.3376,  0.3628,  0.2933,  0.1993,  0.2270,  0.2109,  0.2252,  0.1737,
         0.1151,  0.0970,  0.0518,  0.0090, -0.0781, -0.1075, -0.1452, -0.2126,
        -0.2592, -0.2709, -0.1551, -0.9425, -0.0910, -0.2910, -0.6669, -0.1924,
        -0.3548, -0.5469, -0.2502, -0.49

In [20]:
rf_settings


tensor([ 1.0250,  0.6563, -0.4842, -0.3788,  1.2869])

In [21]:
bunch_length


tensor([-0.6858])

In [22]:
current_profile


tensor([ 1.4061,  1.2965,  1.2059,  1.1369,  1.0940,  1.0753,  1.0729,  1.0814,
         1.0948,  1.1041,  1.1080,  1.1046,  1.0936,  1.0738,  1.0494,  1.0220,
         0.9934,  0.9643,  0.9370,  0.9125,  0.8915,  0.8740,  0.8595,  0.8470,
         0.8360,  0.8264,  0.8186,  0.8138,  0.8127,  0.8161,  0.8237,  0.8344,
         0.8471,  0.8604,  0.8733,  0.8852,  0.8958,  0.9049,  0.9122,  0.9176,
         0.9211,  0.9231,  0.9238,  0.9236,  0.9224,  0.9199,  0.9148,  0.9060,
         0.8928,  0.8749,  0.8543,  0.8332,  0.8142,  0.7988,  0.7874,  0.7787,
         0.7711,  0.7625,  0.7515,  0.7370,  0.7187,  0.6966,  0.6715,  0.6449,
         0.6183,  0.5929,  0.5700,  0.5502,  0.5336,  0.5199,  0.5086,  0.4988,
         0.4902,  0.4821,  0.4734,  0.4635,  0.4518,  0.4380,  0.4223,  0.4053,
         0.3874,  0.3691,  0.3507,  0.3328,  0.3160,  0.3007,  0.2868,  0.2730,
         0.2580,  0.2406,  0.2204,  0.1986,  0.1767,  0.1562,  0.1389,  0.1250,
         0.1143,  0.1059,  0.0984,  0.08

In [27]:
data_module = EuXFELCurrentDataModule(batch_size=64, normalize=True)
data_module


<train_gan.EuXFELCurrentDataModule at 0x2af237490>

In [28]:
data_module.setup(stage="train")


In [29]:
data_loader = data_module.train_dataloader()
data_loader


<torch.utils.data.dataloader.DataLoader at 0x2af2376d0>

In [30]:
len(data_loader)


300