In [1]:
import torch

from train_gan import (
    Critic,
    EuXFELCurrentDataModule,
    EuXFELCurrentDataset,
    Generator,
    ConvolutionalDecoder,
    ConvolutionalEncoder,
)


In [2]:
formfactor_encoder = ConvolutionalEncoder(signal_dims=240, latent_dims=10)
formfactor_encoder


ConvolutionalEncoder(
  (convnet): Sequential(
    (0): Conv1d(1, 8, kernel_size=(3,), stride=(2,), padding=(1,))
    (1): LeakyReLU(negative_slope=0.01)
    (2): Conv1d(8, 16, kernel_size=(3,), stride=(2,), padding=(1,))
    (3): LeakyReLU(negative_slope=0.01)
    (4): Conv1d(16, 32, kernel_size=(3,), stride=(2,), padding=(1,))
    (5): LeakyReLU(negative_slope=0.01)
  )
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (mlp): Sequential(
    (0): Linear(in_features=960, out_features=100, bias=True)
    (1): LeakyReLU(negative_slope=0.01)
    (2): Linear(in_features=100, out_features=50, bias=True)
    (3): LeakyReLU(negative_slope=0.01)
    (4): Linear(in_features=50, out_features=10, bias=True)
  )
)

In [3]:
formfactor = torch.rand(3, 240)
encoded = formfactor_encoder(formfactor)
print(f"{encoded = }")
print(f"{encoded.requires_grad = }")


encoded = tensor([[-0.1044, -0.1262, -0.0411, -0.0960, -0.0570, -0.1411,  0.0819,  0.0235,
         -0.0648, -0.0935],
        [-0.1040, -0.1261, -0.0408, -0.0961, -0.0574, -0.1407,  0.0816,  0.0229,
         -0.0650, -0.0933],
        [-0.1048, -0.1260, -0.0408, -0.0958, -0.0573, -0.1414,  0.0820,  0.0226,
         -0.0650, -0.0936]], grad_fn=<AddmmBackward0>)
encoded.requires_grad = True


In [4]:
current_encoder = ConvolutionalEncoder(signal_dims=300, latent_dims=10)
current_encoder


ConvolutionalEncoder(
  (convnet): Sequential(
    (0): Conv1d(1, 8, kernel_size=(3,), stride=(2,), padding=(1,))
    (1): LeakyReLU(negative_slope=0.01)
    (2): Conv1d(8, 16, kernel_size=(3,), stride=(2,), padding=(1,))
    (3): LeakyReLU(negative_slope=0.01)
    (4): Conv1d(16, 32, kernel_size=(3,), stride=(2,), padding=(1,))
    (5): LeakyReLU(negative_slope=0.01)
  )
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (mlp): Sequential(
    (0): Linear(in_features=1216, out_features=100, bias=True)
    (1): LeakyReLU(negative_slope=0.01)
    (2): Linear(in_features=100, out_features=50, bias=True)
    (3): LeakyReLU(negative_slope=0.01)
    (4): Linear(in_features=50, out_features=10, bias=True)
  )
)

In [5]:
current_profile = torch.rand(3, 300)
encoded = current_encoder(current_profile)
print(f"{encoded = }")
print(f"{encoded.requires_grad = }")


encoded = tensor([[ 0.0040, -0.0270, -0.0408, -0.1242,  0.0279, -0.1028, -0.1106,  0.1344,
         -0.0174, -0.0546],
        [ 0.0033, -0.0286, -0.0406, -0.1248,  0.0279, -0.1040, -0.1130,  0.1342,
         -0.0175, -0.0550],
        [ 0.0033, -0.0276, -0.0395, -0.1246,  0.0304, -0.1035, -0.1098,  0.1344,
         -0.0161, -0.0537]], grad_fn=<AddmmBackward0>)
encoded.requires_grad = True


In [6]:
current_decoder = ConvolutionalDecoder(latent_dims=10 + 5 + 1, signal_dims=300)
current_decoder


ConvolutionalDecoder(
  (mlp): Sequential(
    (0): Linear(in_features=16, out_features=50, bias=True)
    (1): LeakyReLU(negative_slope=0.01)
    (2): Linear(in_features=50, out_features=100, bias=True)
    (3): LeakyReLU(negative_slope=0.01)
    (4): Linear(in_features=100, out_features=1216, bias=True)
    (5): LeakyReLU(negative_slope=0.01)
  )
  (unflatten): Unflatten(dim=1, unflattened_size=(32, 38))
  (convnet): Sequential(
    (0): ConvTranspose1d(32, 16, kernel_size=(3,), stride=(2,), padding=(1,))
    (1): LeakyReLU(negative_slope=0.01)
    (2): ConvTranspose1d(16, 8, kernel_size=(3,), stride=(2,), padding=(1,), output_padding=(1,))
    (3): LeakyReLU(negative_slope=0.01)
    (4): ConvTranspose1d(8, 1, kernel_size=(3,), stride=(2,), padding=(1,), output_padding=(1,))
    (5): ReLU()
  )
)

In [7]:
current_decoder.parameters().__next__().requires_grad


True

In [8]:
latent = torch.rand(3, 10 + 5 + 1)
decoded = current_decoder(latent)
print(f"{decoded = }")
print(f"{decoded.requires_grad = }")


decoded = tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 

In [9]:
formfactor_decoder = ConvolutionalDecoder(latent_dims=10 + 5 + 1, signal_dims=240)
formfactor_decoder


ConvolutionalDecoder(
  (mlp): Sequential(
    (0): Linear(in_features=16, out_features=50, bias=True)
    (1): LeakyReLU(negative_slope=0.01)
    (2): Linear(in_features=50, out_features=100, bias=True)
    (3): LeakyReLU(negative_slope=0.01)
    (4): Linear(in_features=100, out_features=960, bias=True)
    (5): LeakyReLU(negative_slope=0.01)
  )
  (unflatten): Unflatten(dim=1, unflattened_size=(32, 30))
  (convnet): Sequential(
    (0): ConvTranspose1d(32, 16, kernel_size=(3,), stride=(2,), padding=(1,), output_padding=(1,))
    (1): LeakyReLU(negative_slope=0.01)
    (2): ConvTranspose1d(16, 8, kernel_size=(3,), stride=(2,), padding=(1,), output_padding=(1,))
    (3): LeakyReLU(negative_slope=0.01)
    (4): ConvTranspose1d(8, 1, kernel_size=(3,), stride=(2,), padding=(1,), output_padding=(1,))
    (5): ReLU()
  )
)

In [10]:
latent = torch.rand(3, 10 + 5 + 1)
decoded = formfactor_decoder(latent)
print(f"{decoded = }")
print(f"{decoded.requires_grad = }")


decoded = tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 

In [11]:
generator = Generator()
generator


Generator(
  (formfactor_encoder): ConvolutionalEncoder(
    (convnet): Sequential(
      (0): Conv1d(1, 8, kernel_size=(3,), stride=(2,), padding=(1,))
      (1): LeakyReLU(negative_slope=0.01)
      (2): Conv1d(8, 16, kernel_size=(3,), stride=(2,), padding=(1,))
      (3): LeakyReLU(negative_slope=0.01)
      (4): Conv1d(16, 32, kernel_size=(3,), stride=(2,), padding=(1,))
      (5): LeakyReLU(negative_slope=0.01)
    )
    (flatten): Flatten(start_dim=1, end_dim=-1)
    (mlp): Sequential(
      (0): Linear(in_features=960, out_features=100, bias=True)
      (1): LeakyReLU(negative_slope=0.01)
      (2): Linear(in_features=100, out_features=50, bias=True)
      (3): LeakyReLU(negative_slope=0.01)
      (4): Linear(in_features=50, out_features=10, bias=True)
    )
  )
  (scalar_spectral_combine_mlp): Sequential(
    (0): Linear(in_features=15, out_features=50, bias=True)
    (1): LeakyReLU(negative_slope=0.01)
    (2): Linear(in_features=50, out_features=20, bias=True)
    (3): LeakyR

In [12]:
generator.parameters().__next__().requires_grad


True

In [13]:
rf_settings = torch.rand(3, 5)
formfactor = torch.rand(3, 240)

current_profile, bunch_length = generator(rf_settings, formfactor)
print(f"{current_profile.size() = }")
print(f"{current_profile.requires_grad = }")
print(f"{bunch_length.size() = }")
print(f"{bunch_length.requires_grad = }")


current_profile.size() = torch.Size([3, 300])
current_profile.requires_grad = True
bunch_length.size() = torch.Size([3, 1])
bunch_length.requires_grad = True


In [14]:
critic = Critic()
critic


Critic(
  (formfactor_encoder): ConvolutionalEncoder(
    (convnet): Sequential(
      (0): Conv1d(1, 8, kernel_size=(3,), stride=(2,), padding=(1,))
      (1): LeakyReLU(negative_slope=0.01)
      (2): Conv1d(8, 16, kernel_size=(3,), stride=(2,), padding=(1,))
      (3): LeakyReLU(negative_slope=0.01)
      (4): Conv1d(16, 32, kernel_size=(3,), stride=(2,), padding=(1,))
      (5): LeakyReLU(negative_slope=0.01)
    )
    (flatten): Flatten(start_dim=1, end_dim=-1)
    (mlp): Sequential(
      (0): Linear(in_features=960, out_features=100, bias=True)
      (1): LeakyReLU(negative_slope=0.01)
      (2): Linear(in_features=100, out_features=50, bias=True)
      (3): LeakyReLU(negative_slope=0.01)
      (4): Linear(in_features=50, out_features=10, bias=True)
    )
  )
  (current_encoder): ConvolutionalEncoder(
    (convnet): Sequential(
      (0): Conv1d(1, 8, kernel_size=(3,), stride=(2,), padding=(1,))
      (1): LeakyReLU(negative_slope=0.01)
      (2): Conv1d(8, 16, kernel_size=(3,),

In [15]:
critique = critic(rf_settings, formfactor, current_profile, bunch_length)
print(f"{critique = }")
print(f"{critique.requires_grad = }")


critique = tensor([[-0.1880],
        [-0.1769],
        [-0.1727]], grad_fn=<AddmmBackward0>)
critique.requires_grad = True


In [16]:
dataset = EuXFELCurrentDataset(normalize=True)
dataset


<train_gan.EuXFELCurrentDataset at 0x28fd138e0>

In [17]:
len(dataset)


25600

In [18]:
(rf_settings, formfactor), (current_profile, bunch_length) = dataset[0]


In [19]:
formfactor


tensor([ 0.1088,  0.8982,  0.7778, -0.3842,  0.3934,  1.1806,  0.3513,  0.5333,
         0.4273,  0.4169,  0.9124,  0.5842,  0.5516,  0.8963,  0.8173, -0.0905,
         0.8763,  0.7232,  0.8295,  0.7835,  0.7825,  0.8283,  0.8010,  0.8343,
         0.7814,  0.8849,  0.8978,  0.8681,  0.8983,  0.9119,  0.9508,  0.6962,
         0.9910,  0.9971,  0.8135,  1.0166,  0.8453,  1.0088,  0.9617,  1.0021,
         1.0230,  0.9885,  1.0004,  1.0561,  1.0639,  1.0924,  1.0754,  1.0793,
         1.0991,  1.0668,  1.1638,  1.1874,  1.1758,  1.2214,  1.2338,  1.2624,
         1.3026,  1.2957,  1.3394,  1.3627,  1.1388,  1.2419,  1.4109,  1.3788,
         1.3503,  1.4213,  1.3817,  1.3550,  1.4236,  1.4270,  1.4571,  1.5076,
         1.4982,  1.5231,  1.5010,  1.5507,  1.5560,  1.5790,  1.5979,  1.6012,
         1.6426,  1.6666,  1.6868,  1.7104,  1.7388,  1.7728,  1.7886,  1.8227,
         1.8537,  1.8489,  1.6903,  1.8300,  1.8327,  1.8037,  1.8014,  1.8423,
         1.8530,  1.8607,  1.9009,  1.89

In [20]:
rf_settings


tensor([ 0.3616,  1.7155, -1.3814, -0.2878, -0.9114])

In [21]:
bunch_length


tensor([-1.0020])

In [22]:
current_profile


tensor([-0.4532, -0.4376, -0.4221, -0.4059, -0.3925, -0.3846, -0.3829, -0.3856,
        -0.3896, -0.3934, -0.3964, -0.3985, -0.3999, -0.4005, -0.4007, -0.4005,
        -0.4002, -0.4000, -0.3998, -0.3997, -0.3998, -0.4000, -0.4004, -0.4008,
        -0.4013, -0.4015, -0.4019, -0.4037, -0.4074, -0.4126, -0.4183, -0.4243,
        -0.4309, -0.4379, -0.4454, -0.4532, -0.4611, -0.4690, -0.4769, -0.4847,
        -0.4922, -0.4992, -0.5053, -0.5109, -0.5174, -0.5251, -0.5326, -0.5390,
        -0.5442, -0.5495, -0.5561, -0.5640, -0.5720, -0.5795, -0.5863, -0.5925,
        -0.5986, -0.6056, -0.6132, -0.6202, -0.6254, -0.6270, -0.6191, -0.5927,
        -0.5399, -0.4538, -0.3326, -0.1847, -0.0210,  0.1518,  0.3331,  0.5266,
         0.7321,  0.9510,  1.1886,  1.4529,  1.7598,  2.1348,  2.5822,  3.0652,
         3.4982,  3.7927,  3.9313,  3.9726,  4.0022,  4.0659,  4.1547,  4.2426,
         4.3249,  4.3937,  4.4215,  4.3959,  4.3457,  4.2917,  4.2158,  4.0946,
         3.9514,  3.8225,  3.7079,  3.58

In [23]:
data_module = EuXFELCurrentDataModule(batch_size=64, normalize=True)
data_module


<train_gan.EuXFELCurrentDataModule at 0x28fd13f40>

In [24]:
data_module.setup(stage="fit")


In [25]:
data_loader = data_module.train_dataloader()
data_loader, len(data_loader)


(<torch.utils.data.dataloader.DataLoader at 0x28fd13a30>, 400)

In [26]:
data_loader = data_module.val_dataloader()
data_loader, len(data_loader)


(<torch.utils.data.dataloader.DataLoader at 0x28fd13e50>, 50)

In [27]:
data_loader = data_module.test_dataloader()
data_loader, len(data_loader)


(<torch.utils.data.dataloader.DataLoader at 0x28fd134f0>, 50)