In [1]:
import torch

from train_gan import (
    Critic,
    EuXFELCurrentDataModule,
    EuXFELCurrentDataset,
    Generator,
    ConvolutionalDecoder,
    ConvolutionalEncoder,
)


In [2]:
# Encoder taking 240-sample signals (the form factor) down to a 10-dim latent.
formfactor_encoder = ConvolutionalEncoder(
    signal_dims=240,
    latent_dims=10,
)
formfactor_encoder


ConvolutionalEncoder(
  (convnet): Sequential(
    (0): Conv1d(1, 8, kernel_size=(3,), stride=(2,), padding=(1,))
    (1): LeakyReLU(negative_slope=0.01)
    (2): Conv1d(8, 16, kernel_size=(3,), stride=(2,), padding=(1,))
    (3): LeakyReLU(negative_slope=0.01)
    (4): Conv1d(16, 32, kernel_size=(3,), stride=(2,), padding=(1,))
    (5): LeakyReLU(negative_slope=0.01)
  )
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (mlp): Sequential(
    (0): Linear(in_features=960, out_features=100, bias=True)
    (1): LeakyReLU(negative_slope=0.01)
    (2): Linear(in_features=100, out_features=50, bias=True)
    (3): LeakyReLU(negative_slope=0.01)
    (4): Linear(in_features=50, out_features=10, bias=True)
  )
)

In [3]:
# Smoke test: run a random batch (3 signals x 240 samples) through the encoder.
formfactor = torch.rand((3, 240))
encoded = formfactor_encoder(formfactor)

# Show the latent codes and whether they are attached to the autograd graph.
print(f"{encoded = }")
print(f"{encoded.requires_grad = }")


encoded = tensor([[ 0.0735,  0.0666,  0.1148, -0.0146,  0.0067, -0.0111,  0.1087, -0.1036,
          0.1379,  0.0438],
        [ 0.0734,  0.0672,  0.1120, -0.0131,  0.0072, -0.0098,  0.1079, -0.1054,
          0.1385,  0.0463],
        [ 0.0733,  0.0632,  0.1110, -0.0153,  0.0064, -0.0092,  0.1070, -0.1028,
          0.1366,  0.0457]], grad_fn=<AddmmBackward0>)
encoded.requires_grad = True


In [4]:
# Same encoder architecture, sized for 300-sample signals (the current profile).
current_encoder = ConvolutionalEncoder(
    signal_dims=300,
    latent_dims=10,
)
current_encoder


ConvolutionalEncoder(
  (convnet): Sequential(
    (0): Conv1d(1, 8, kernel_size=(3,), stride=(2,), padding=(1,))
    (1): LeakyReLU(negative_slope=0.01)
    (2): Conv1d(8, 16, kernel_size=(3,), stride=(2,), padding=(1,))
    (3): LeakyReLU(negative_slope=0.01)
    (4): Conv1d(16, 32, kernel_size=(3,), stride=(2,), padding=(1,))
    (5): LeakyReLU(negative_slope=0.01)
  )
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (mlp): Sequential(
    (0): Linear(in_features=1216, out_features=100, bias=True)
    (1): LeakyReLU(negative_slope=0.01)
    (2): Linear(in_features=100, out_features=50, bias=True)
    (3): LeakyReLU(negative_slope=0.01)
    (4): Linear(in_features=50, out_features=10, bias=True)
  )
)

In [5]:
# Smoke test: run a random batch (3 signals x 300 samples) through the encoder.
current_profile = torch.rand((3, 300))
encoded = current_encoder(current_profile)

# Show the latent codes and whether they are attached to the autograd graph.
print(f"{encoded = }")
print(f"{encoded.requires_grad = }")


encoded = tensor([[ 0.0178, -0.0527,  0.0861, -0.0232,  0.0782, -0.0730, -0.1707, -0.0686,
         -0.1039, -0.0677],
        [ 0.0183, -0.0512,  0.0877, -0.0220,  0.0779, -0.0746, -0.1730, -0.0685,
         -0.1057, -0.0651],
        [ 0.0170, -0.0521,  0.0877, -0.0247,  0.0781, -0.0722, -0.1686, -0.0654,
         -0.1055, -0.0672]], grad_fn=<AddmmBackward0>)
encoded.requires_grad = True


In [6]:
# Decoder mapping a 16-dim latent vector back to a 300-sample signal.
# NOTE(review): 10 + 5 + 1 presumably stands for encoded form factor + RF
# settings + bunch length (cf. the generator inputs used later in this
# notebook) — confirm against train_gan.
current_decoder = ConvolutionalDecoder(
    latent_dims=10 + 5 + 1,
    signal_dims=300,
)
current_decoder


ConvolutionalDecoder(
  (mlp): Sequential(
    (0): Linear(in_features=16, out_features=50, bias=True)
    (1): LeakyReLU(negative_slope=0.01)
    (2): Linear(in_features=50, out_features=100, bias=True)
    (3): LeakyReLU(negative_slope=0.01)
    (4): Linear(in_features=100, out_features=1216, bias=True)
    (5): LeakyReLU(negative_slope=0.01)
  )
  (unflatten): Unflatten(dim=1, unflattened_size=(32, 38))
  (convnet): Sequential(
    (0): ConvTranspose1d(32, 16, kernel_size=(3,), stride=(2,), padding=(1,))
    (1): LeakyReLU(negative_slope=0.01)
    (2): ConvTranspose1d(16, 8, kernel_size=(3,), stride=(2,), padding=(1,), output_padding=(1,))
    (3): LeakyReLU(negative_slope=0.01)
    (4): ConvTranspose1d(8, 1, kernel_size=(3,), stride=(2,), padding=(1,), output_padding=(1,))
    (5): ReLU()
  )
)

In [7]:
# Spot-check that the first parameter tensor yielded by the decoder is trainable.
# The built-in next() is the idiomatic way to pull one item from an iterator.
next(current_decoder.parameters()).requires_grad


True

In [8]:
# Smoke test: decode a random batch of 3 latent vectors (10 + 5 + 1 = 16 dims).
latent = torch.rand((3, 10 + 5 + 1))
decoded = current_decoder(latent)

# NOTE(review): in the visible part of the recorded output every other sample
# is exactly 0 — worth checking the transposed convolutions / final ReLU in
# train_gan.
print(f"{decoded = }")
print(f"{decoded.requires_grad = }")


decoded = tensor([[0.0000, 0.0349, 0.0000, 0.0194, 0.0000, 0.0282, 0.0000, 0.0298, 0.0000,
         0.0380, 0.0000, 0.0375, 0.0000, 0.0315, 0.0000, 0.0336, 0.0000, 0.0274,
         0.0000, 0.0502, 0.0000, 0.0182, 0.0000, 0.0306, 0.0000, 0.0228, 0.0000,
         0.0276, 0.0000, 0.0306, 0.0000, 0.0108, 0.0000, 0.0286, 0.0000, 0.0223,
         0.0000, 0.0272, 0.0000, 0.0438, 0.0000, 0.0344, 0.0000, 0.0462, 0.0000,
         0.0307, 0.0000, 0.0172, 0.0000, 0.0312, 0.0000, 0.0508, 0.0000, 0.0385,
         0.0000, 0.0413, 0.0000, 0.0497, 0.0000, 0.0648, 0.0000, 0.0441, 0.0000,
         0.0163, 0.0000, 0.0232, 0.0000, 0.0155, 0.0000, 0.0337, 0.0000, 0.0220,
         0.0000, 0.0176, 0.0000, 0.0289, 0.0000, 0.0290, 0.0000, 0.0356, 0.0000,
         0.0181, 0.0000, 0.0204, 0.0000, 0.0208, 0.0000, 0.0299, 0.0000, 0.0115,
         0.0000, 0.0244, 0.0000, 0.0150, 0.0000, 0.0333, 0.0000, 0.0376, 0.0000,
         0.0253, 0.0000, 0.0178, 0.0000, 0.0024, 0.0000, 0.0226, 0.0000, 0.0184,
         0.0000, 0

In [9]:
# Decoder mapping a 16-dim latent vector back to a 240-sample signal.
formfactor_decoder = ConvolutionalDecoder(
    latent_dims=10 + 5 + 1,
    signal_dims=240,
)
formfactor_decoder


ConvolutionalDecoder(
  (mlp): Sequential(
    (0): Linear(in_features=16, out_features=50, bias=True)
    (1): LeakyReLU(negative_slope=0.01)
    (2): Linear(in_features=50, out_features=100, bias=True)
    (3): LeakyReLU(negative_slope=0.01)
    (4): Linear(in_features=100, out_features=960, bias=True)
    (5): LeakyReLU(negative_slope=0.01)
  )
  (unflatten): Unflatten(dim=1, unflattened_size=(32, 30))
  (convnet): Sequential(
    (0): ConvTranspose1d(32, 16, kernel_size=(3,), stride=(2,), padding=(1,), output_padding=(1,))
    (1): LeakyReLU(negative_slope=0.01)
    (2): ConvTranspose1d(16, 8, kernel_size=(3,), stride=(2,), padding=(1,), output_padding=(1,))
    (3): LeakyReLU(negative_slope=0.01)
    (4): ConvTranspose1d(8, 1, kernel_size=(3,), stride=(2,), padding=(1,), output_padding=(1,))
    (5): ReLU()
  )
)

In [10]:
# Smoke test: decode a random batch of 3 latent vectors (10 + 5 + 1 = 16 dims).
latent = torch.rand((3, 10 + 5 + 1))
decoded = formfactor_decoder(latent)

# NOTE(review): the visible part of the recorded output is all zeros — the
# final ReLU may be clamping every sample for this initialisation; verify.
print(f"{decoded = }")
print(f"{decoded.requires_grad = }")


decoded = tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 

In [11]:
# Build the generator with its default arguments and show its module tree.
generator = Generator()
generator


Generator(
  (formfactor_encoder): ConvolutionalEncoder(
    (convnet): Sequential(
      (0): Conv1d(1, 8, kernel_size=(3,), stride=(2,), padding=(1,))
      (1): LeakyReLU(negative_slope=0.01)
      (2): Conv1d(8, 16, kernel_size=(3,), stride=(2,), padding=(1,))
      (3): LeakyReLU(negative_slope=0.01)
      (4): Conv1d(16, 32, kernel_size=(3,), stride=(2,), padding=(1,))
      (5): LeakyReLU(negative_slope=0.01)
    )
    (flatten): Flatten(start_dim=1, end_dim=-1)
    (mlp): Sequential(
      (0): Linear(in_features=960, out_features=100, bias=True)
      (1): LeakyReLU(negative_slope=0.01)
      (2): Linear(in_features=100, out_features=50, bias=True)
      (3): LeakyReLU(negative_slope=0.01)
      (4): Linear(in_features=50, out_features=10, bias=True)
    )
  )
  (current_decoder): ConvolutionalDecoder(
    (mlp): Sequential(
      (0): Linear(in_features=16, out_features=50, bias=True)
      (1): LeakyReLU(negative_slope=0.01)
      (2): Linear(in_features=50, out_features=100

In [12]:
# Spot-check that the first parameter tensor yielded by the generator is trainable.
# The built-in next() is the idiomatic way to pull one item from an iterator.
next(generator.parameters()).requires_grad


True

In [13]:
# Random stand-ins for one batch (size 3) of generator inputs.
formfactor = torch.rand((3, 240))
rf_settings = torch.rand((3, 5))
bunch_length = torch.rand((3, 1))

# Forward pass; only the output size and autograd status are reported.
current_profile = generator(
    formfactor,
    rf_settings,
    bunch_length,
)
print(f"{current_profile.size() = }")
print(f"{current_profile.requires_grad = }")


current_profile.size() = torch.Size([3, 300])
current_profile.requires_grad = True


In [14]:
# Build the critic with its default arguments and show its module tree.
critic = Critic()
critic


Critic(
  (formfactor_encoder): ConvolutionalEncoder(
    (convnet): Sequential(
      (0): Conv1d(1, 8, kernel_size=(3,), stride=(2,), padding=(1,))
      (1): LeakyReLU(negative_slope=0.01)
      (2): Conv1d(8, 16, kernel_size=(3,), stride=(2,), padding=(1,))
      (3): LeakyReLU(negative_slope=0.01)
      (4): Conv1d(16, 32, kernel_size=(3,), stride=(2,), padding=(1,))
      (5): LeakyReLU(negative_slope=0.01)
    )
    (flatten): Flatten(start_dim=1, end_dim=-1)
    (mlp): Sequential(
      (0): Linear(in_features=960, out_features=100, bias=True)
      (1): LeakyReLU(negative_slope=0.01)
      (2): Linear(in_features=100, out_features=50, bias=True)
      (3): LeakyReLU(negative_slope=0.01)
      (4): Linear(in_features=50, out_features=10, bias=True)
    )
  )
  (current_encoder): ConvolutionalEncoder(
    (convnet): Sequential(
      (0): Conv1d(1, 8, kernel_size=(3,), stride=(2,), padding=(1,))
      (1): LeakyReLU(negative_slope=0.01)
      (2): Conv1d(8, 16, kernel_size=(3,),

In [15]:
# Score the generated current profiles together with the inputs they were
# generated from; the recorded output holds one value per batch element.
critique = critic(
    current_profile,
    formfactor,
    rf_settings,
    bunch_length,
)
print(f"{critique = }")
print(f"{critique.requires_grad = }")


critique = tensor([[-0.0222],
        [-0.0188],
        [-0.0079]], grad_fn=<AddmmBackward0>)
critique.requires_grad = True


In [16]:
# Instantiate the dataset with normalisation switched on.
dataset = EuXFELCurrentDataset(normalize=True)
dataset


<train_gan.EuXFELCurrentDataset at 0x2aa32e430>

In [17]:
# Number of samples exposed by the dataset.
len(dataset)


25600

In [18]:
# A sample is laid out as ((formfactor, rf_settings, bunch_length), current_profile).
# NOTE(review): this rebinds the names that held the random batches used in the
# generator/critic smoke tests above, so re-running those cells after this one
# would feed them single, unbatched samples.
(formfactor, rf_settings, bunch_length), current_profile = dataset[0]


In [19]:
# Form factor of sample 0 (dataset was built with normalize=True).
formfactor


tensor([ 0.8967, -0.0496,  0.6493,  0.6142,  0.6210,  0.3280,  0.4514,  0.7230,
         0.5122,  0.4343,  0.7586,  0.7439,  0.7516,  0.7016,  0.6730, -0.3678,
         0.7603,  0.7358,  0.7098,  0.7879,  0.7730,  0.9261,  0.8817,  0.8326,
         0.8834,  0.8499,  0.8204,  0.8329,  0.8822,  0.9108,  0.8027,  1.0196,
         0.9109,  0.6966,  0.7671,  0.9020,  0.8463,  0.9784,  0.8875,  0.9628,
         1.0164,  1.0460,  1.0049,  1.0713,  1.0758,  1.1035,  1.0718,  1.0805,
         1.1297,  1.1424,  1.1248,  1.1508,  1.1933,  1.2206,  1.2327,  1.2506,
         1.2511,  1.2965,  1.3434,  1.3815,  1.5425,  1.1233,  1.4732,  1.3150,
         1.4260,  1.4542,  1.3917,  1.4061,  1.4489,  1.4558,  1.4331,  1.4687,
         1.5044,  1.5263,  1.4788,  1.5456,  1.5517,  1.5728,  1.6156,  1.6282,
         1.6280,  1.6681,  1.6934,  1.7095,  1.7405,  1.7749,  1.7983,  1.8291,
         1.8542,  1.8679,  1.7163,  1.8633,  1.7904,  1.8032,  1.7841,  1.8590,
         1.8147,  1.8813,  1.8448,  1.89

In [20]:
# RF settings of sample 0 — five values in the recorded output.
rf_settings


tensor([ 0.3616,  1.7155, -1.3814, -0.2878, -0.9114])

In [21]:
# Bunch length of sample 0 — a single-element tensor in the recorded output.
bunch_length


tensor([-1.0020])

In [22]:
# Current profile of sample 0 (dataset was built with normalize=True).
current_profile


tensor([-0.4532, -0.4376, -0.4221, -0.4059, -0.3925, -0.3846, -0.3829, -0.3856,
        -0.3896, -0.3934, -0.3964, -0.3985, -0.3999, -0.4005, -0.4007, -0.4005,
        -0.4002, -0.4000, -0.3998, -0.3997, -0.3998, -0.4000, -0.4004, -0.4008,
        -0.4013, -0.4015, -0.4019, -0.4037, -0.4074, -0.4126, -0.4183, -0.4243,
        -0.4309, -0.4379, -0.4454, -0.4532, -0.4611, -0.4690, -0.4769, -0.4847,
        -0.4922, -0.4992, -0.5053, -0.5109, -0.5174, -0.5251, -0.5326, -0.5390,
        -0.5442, -0.5495, -0.5561, -0.5640, -0.5720, -0.5795, -0.5863, -0.5925,
        -0.5986, -0.6056, -0.6132, -0.6202, -0.6254, -0.6270, -0.6191, -0.5927,
        -0.5399, -0.4538, -0.3326, -0.1847, -0.0210,  0.1518,  0.3331,  0.5266,
         0.7321,  0.9510,  1.1886,  1.4529,  1.7598,  2.1348,  2.5822,  3.0652,
         3.4982,  3.7927,  3.9313,  3.9726,  4.0022,  4.0659,  4.1547,  4.2426,
         4.3249,  4.3937,  4.4215,  4.3959,  4.3457,  4.2917,  4.2158,  4.0946,
         3.9514,  3.8225,  3.7079,  3.58

In [23]:
# Data module wrapping the dataset: batches of 64, normalisation switched on.
data_module = EuXFELCurrentDataModule(
    batch_size=64,
    normalize=True,
)
data_module


<train_gan.EuXFELCurrentDataModule at 0x2aa32e190>

In [25]:
# Run the data module's setup for the "fit" stage before requesting dataloaders.
data_module.setup(stage="fit")


In [28]:
# Training loader and its number of batches.
# NOTE(review): `data_loader` is rebound to the val/test loaders in later cells.
data_loader = data_module.train_dataloader()
(data_loader, len(data_loader))


(<torch.utils.data.dataloader.DataLoader at 0x2aa32e850>, 400)

In [29]:
# Validation loader and its number of batches (rebinds `data_loader`).
data_loader = data_module.val_dataloader()
(data_loader, len(data_loader))


(<torch.utils.data.dataloader.DataLoader at 0x2aa32e130>, 50)

In [30]:
# Test loader and its number of batches (rebinds `data_loader`).
data_loader = data_module.test_dataloader()
(data_loader, len(data_loader))


(<torch.utils.data.dataloader.DataLoader at 0x2aa32e3d0>, 50)