In [1]:
import torch

from train_gan import (
    Critic,
    EuXFELCurrentDataModule,
    EuXFELCurrentDataset,
    Generator,
    ConvolutionalDecoder,
    ConvolutionalEncoder,
)


In [2]:
# Seed torch's global RNG once, before any weights are initialised or random
# inputs are drawn, so that "Restart & Run All" reproduces the outputs below.
torch.manual_seed(42)

# Encoder mapping a 240-dimensional formfactor signal to a 10-dimensional latent.
formfactor_encoder = ConvolutionalEncoder(signal_dims=240, latent_dims=10)
formfactor_encoder


ConvolutionalEncoder(
  (convnet): Sequential(
    (0): Conv1d(1, 8, kernel_size=(3,), stride=(2,), padding=(1,))
    (1): LeakyReLU(negative_slope=0.01)
    (2): Conv1d(8, 16, kernel_size=(3,), stride=(2,), padding=(1,))
    (3): LeakyReLU(negative_slope=0.01)
    (4): Conv1d(16, 32, kernel_size=(3,), stride=(2,), padding=(1,))
    (5): LeakyReLU(negative_slope=0.01)
  )
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (mlp): Sequential(
    (0): Linear(in_features=960, out_features=100, bias=True)
    (1): LeakyReLU(negative_slope=0.01)
    (2): Linear(in_features=100, out_features=50, bias=True)
    (3): LeakyReLU(negative_slope=0.01)
    (4): Linear(in_features=50, out_features=10, bias=True)
  )
)

In [3]:
# Smoke test: encode a random batch of 3 formfactors, then check the latent
# shape and that the result is part of the autograd graph.
formfactor = torch.rand(3, 240)
encoded = formfactor_encoder(formfactor)
print(f"{encoded = }")
print(f"{encoded.requires_grad = }")

# Fail loudly under Run All instead of relying on a reader to inspect the prints.
assert encoded.shape == (formfactor.shape[0], 10), encoded.shape
assert encoded.requires_grad


encoded = tensor([[-0.0639,  0.0910, -0.0980, -0.0625, -0.0759, -0.0094,  0.0636, -0.1240,
         -0.0454,  0.0487],
        [-0.0655,  0.0900, -0.0978, -0.0640, -0.0763, -0.0097,  0.0666, -0.1220,
         -0.0442,  0.0495],
        [-0.0644,  0.0895, -0.0972, -0.0617, -0.0767, -0.0083,  0.0657, -0.1216,
         -0.0453,  0.0498]], grad_fn=<AddmmBackward0>)
encoded.requires_grad = True


In [4]:
# Second encoder instance: 300-dimensional current profile -> 10 latent values.
current_encoder = ConvolutionalEncoder(latent_dims=10, signal_dims=300)
current_encoder


ConvolutionalEncoder(
  (convnet): Sequential(
    (0): Conv1d(1, 8, kernel_size=(3,), stride=(2,), padding=(1,))
    (1): LeakyReLU(negative_slope=0.01)
    (2): Conv1d(8, 16, kernel_size=(3,), stride=(2,), padding=(1,))
    (3): LeakyReLU(negative_slope=0.01)
    (4): Conv1d(16, 32, kernel_size=(3,), stride=(2,), padding=(1,))
    (5): LeakyReLU(negative_slope=0.01)
  )
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (mlp): Sequential(
    (0): Linear(in_features=1216, out_features=100, bias=True)
    (1): LeakyReLU(negative_slope=0.01)
    (2): Linear(in_features=100, out_features=50, bias=True)
    (3): LeakyReLU(negative_slope=0.01)
    (4): Linear(in_features=50, out_features=10, bias=True)
  )
)

In [5]:
# Smoke test: encode a random batch of 3 current profiles, then check the
# latent shape and that the result is part of the autograd graph.
current_profile = torch.rand(3, 300)
encoded = current_encoder(current_profile)
print(f"{encoded = }")
print(f"{encoded.requires_grad = }")

# Fail loudly under Run All instead of relying on a reader to inspect the prints.
assert encoded.shape == (current_profile.shape[0], 10), encoded.shape
assert encoded.requires_grad


encoded = tensor([[ 0.0496, -0.0315,  0.0538,  0.0646,  0.1436,  0.1433,  0.1666,  0.0769,
         -0.1351,  0.1059],
        [ 0.0520, -0.0311,  0.0539,  0.0665,  0.1418,  0.1418,  0.1666,  0.0758,
         -0.1379,  0.1068],
        [ 0.0514, -0.0323,  0.0541,  0.0645,  0.1425,  0.1420,  0.1664,  0.0743,
         -0.1361,  0.1067]], grad_fn=<AddmmBackward0>)
encoded.requires_grad = True


In [6]:
# Decoder: 16-dimensional latent -> 300-dimensional current profile.
# NOTE(review): 10 + 5 + 1 presumably means encoder latent (10) + RF settings (5)
# + one extra scalar (bunch length?) - confirm against Generator.
current_decoder = ConvolutionalDecoder(signal_dims=300, latent_dims=10 + 5 + 1)
current_decoder


ConvolutionalDecoder(
  (mlp): Sequential(
    (0): Linear(in_features=16, out_features=50, bias=True)
    (1): LeakyReLU(negative_slope=0.01)
    (2): Linear(in_features=50, out_features=100, bias=True)
    (3): LeakyReLU(negative_slope=0.01)
    (4): Linear(in_features=100, out_features=1216, bias=True)
    (5): LeakyReLU(negative_slope=0.01)
  )
  (unflatten): Unflatten(dim=1, unflattened_size=(32, 38))
  (convnet): Sequential(
    (0): ConvTranspose1d(32, 16, kernel_size=(3,), stride=(2,), padding=(1,))
    (1): LeakyReLU(negative_slope=0.01)
    (2): ConvTranspose1d(16, 8, kernel_size=(3,), stride=(2,), padding=(1,), output_padding=(1,))
    (3): LeakyReLU(negative_slope=0.01)
    (4): ConvTranspose1d(8, 1, kernel_size=(3,), stride=(2,), padding=(1,), output_padding=(1,))
    (5): ReLU()
  )
)

In [7]:
# Check that every decoder parameter is trainable. Looking only at the first
# parameter (via `parameters().__next__()`) would miss frozen later layers.
all(param.requires_grad for param in current_decoder.parameters())


True

In [8]:
# Smoke test: decode a random batch of 3 latents into current profiles.
latent = torch.rand(3, 10 + 5 + 1)
decoded = current_decoder(latent)
print(f"{decoded = }")
print(f"{decoded.requires_grad = }")

# The transposed convs upsample 38 -> 75 -> 150 -> 300 (see repr above).
assert decoded.shape == (latent.shape[0], 300), decoded.shape
assert decoded.requires_grad


decoded = tensor([[0.4901, 0.5386, 0.5056, 0.5496, 0.5053, 0.5436, 0.5077, 0.5483, 0.4996,
         0.5443, 0.4973, 0.5507, 0.4987, 0.5418, 0.4984, 0.5412, 0.4960, 0.5370,
         0.4984, 0.5440, 0.4938, 0.5376, 0.5007, 0.5494, 0.4918, 0.5414, 0.4980,
         0.5513, 0.4957, 0.5401, 0.5017, 0.5492, 0.4946, 0.5358, 0.5051, 0.5509,
         0.5034, 0.5364, 0.5002, 0.5446, 0.4944, 0.5360, 0.4972, 0.5496, 0.4885,
         0.5240, 0.5034, 0.5447, 0.4954, 0.5389, 0.5054, 0.5434, 0.4990, 0.5363,
         0.5025, 0.5404, 0.5058, 0.5441, 0.4982, 0.5571, 0.4864, 0.5300, 0.5043,
         0.5475, 0.5003, 0.5396, 0.4953, 0.5387, 0.4833, 0.5311, 0.5021, 0.5479,
         0.4923, 0.5338, 0.5000, 0.5448, 0.4907, 0.5343, 0.5140, 0.5462, 0.4996,
         0.5387, 0.5063, 0.5496, 0.5027, 0.5333, 0.5078, 0.5531, 0.4980, 0.5320,
         0.5071, 0.5538, 0.4945, 0.5338, 0.5088, 0.5563, 0.4963, 0.5424, 0.5006,
         0.5466, 0.4983, 0.5451, 0.4974, 0.5478, 0.4952, 0.5452, 0.4977, 0.5469,
         0.4956, 0

In [9]:
# Decoder: the same 16-dimensional latent -> 240-dimensional formfactor.
formfactor_decoder = ConvolutionalDecoder(signal_dims=240, latent_dims=10 + 5 + 1)
formfactor_decoder


ConvolutionalDecoder(
  (mlp): Sequential(
    (0): Linear(in_features=16, out_features=50, bias=True)
    (1): LeakyReLU(negative_slope=0.01)
    (2): Linear(in_features=50, out_features=100, bias=True)
    (3): LeakyReLU(negative_slope=0.01)
    (4): Linear(in_features=100, out_features=960, bias=True)
    (5): LeakyReLU(negative_slope=0.01)
  )
  (unflatten): Unflatten(dim=1, unflattened_size=(32, 30))
  (convnet): Sequential(
    (0): ConvTranspose1d(32, 16, kernel_size=(3,), stride=(2,), padding=(1,), output_padding=(1,))
    (1): LeakyReLU(negative_slope=0.01)
    (2): ConvTranspose1d(16, 8, kernel_size=(3,), stride=(2,), padding=(1,), output_padding=(1,))
    (3): LeakyReLU(negative_slope=0.01)
    (4): ConvTranspose1d(8, 1, kernel_size=(3,), stride=(2,), padding=(1,), output_padding=(1,))
    (5): ReLU()
  )
)

In [10]:
# Smoke test: decode a random batch of 3 latents into formfactors.
latent = torch.rand(3, 10 + 5 + 1)
decoded = formfactor_decoder(latent)
print(f"{decoded = }")
print(f"{decoded.requires_grad = }")
# NOTE(review): a previous run printed an all-zero tensor here. The decoder
# ends in ReLU, so if every pre-activation is negative the output (and its
# gradient) is identically zero. Report the zero fraction so this is visible
# without scrolling through the full tensor.
print(f"{(decoded == 0).float().mean().item() = }")

# The transposed convs upsample 30 -> 60 -> 120 -> 240 (see repr above).
assert decoded.shape == (latent.shape[0], 240), decoded.shape
assert decoded.requires_grad


decoded = tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 

In [11]:
# Build the generator with its default constructor arguments and display its
# module tree (the saved repr below is truncated).
generator = Generator()
generator


Generator(
  (formfactor_encoder): ConvolutionalEncoder(
    (convnet): Sequential(
      (0): Conv1d(1, 8, kernel_size=(3,), stride=(2,), padding=(1,))
      (1): LeakyReLU(negative_slope=0.01)
      (2): Conv1d(8, 16, kernel_size=(3,), stride=(2,), padding=(1,))
      (3): LeakyReLU(negative_slope=0.01)
      (4): Conv1d(16, 32, kernel_size=(3,), stride=(2,), padding=(1,))
      (5): LeakyReLU(negative_slope=0.01)
    )
    (flatten): Flatten(start_dim=1, end_dim=-1)
    (mlp): Sequential(
      (0): Linear(in_features=960, out_features=100, bias=True)
      (1): LeakyReLU(negative_slope=0.01)
      (2): Linear(in_features=100, out_features=50, bias=True)
      (3): LeakyReLU(negative_slope=0.01)
      (4): Linear(in_features=50, out_features=10, bias=True)
    )
  )
  (scalar_spectral_combine_mlp): Sequential(
    (0): Linear(in_features=15, out_features=50, bias=True)
    (1): LeakyReLU(negative_slope=0.01)
    (2): Linear(in_features=50, out_features=20, bias=True)
    (3): LeakyR

In [12]:
# Check that every generator parameter is trainable. Looking only at the first
# parameter (via `parameters().__next__()`) would miss frozen later layers.
all(param.requires_grad for param in generator.parameters())


True

In [13]:
# Smoke test of the generator's forward pass on a random batch of 3 samples.
# Each sample takes 5 RF-setting values and a 240-dimensional formfactor, and
# yields a current profile plus a bunch length.
rf_settings = torch.rand(3, 5)
formfactor = torch.rand(3, 240)

current_profile, bunch_length = generator(rf_settings, formfactor)
print(f"{current_profile.size() = }")
print(f"{current_profile.requires_grad = }")
print(f"{bunch_length.size() = }")
print(f"{bunch_length.requires_grad = }")

# Fail loudly under Run All instead of relying on a reader to inspect the prints.
assert current_profile.shape == (rf_settings.shape[0], 300), current_profile.shape
assert bunch_length.shape == (rf_settings.shape[0], 1), bunch_length.shape
assert current_profile.requires_grad and bunch_length.requires_grad


current_profile.size() = torch.Size([3, 300])
current_profile.requires_grad = True
bunch_length.size() = torch.Size([3, 1])
bunch_length.requires_grad = True


In [14]:
# Build the critic with its default constructor arguments and display its
# module tree (the saved repr below is truncated). Per the repr, it has its own
# formfactor and current-profile encoders.
critic = Critic()
critic


Critic(
  (formfactor_encoder): ConvolutionalEncoder(
    (convnet): Sequential(
      (0): Conv1d(1, 8, kernel_size=(3,), stride=(2,), padding=(1,))
      (1): LeakyReLU(negative_slope=0.01)
      (2): Conv1d(8, 16, kernel_size=(3,), stride=(2,), padding=(1,))
      (3): LeakyReLU(negative_slope=0.01)
      (4): Conv1d(16, 32, kernel_size=(3,), stride=(2,), padding=(1,))
      (5): LeakyReLU(negative_slope=0.01)
    )
    (flatten): Flatten(start_dim=1, end_dim=-1)
    (mlp): Sequential(
      (0): Linear(in_features=960, out_features=100, bias=True)
      (1): LeakyReLU(negative_slope=0.01)
      (2): Linear(in_features=100, out_features=50, bias=True)
      (3): LeakyReLU(negative_slope=0.01)
      (4): Linear(in_features=50, out_features=10, bias=True)
    )
  )
  (current_encoder): ConvolutionalEncoder(
    (convnet): Sequential(
      (0): Conv1d(1, 8, kernel_size=(3,), stride=(2,), padding=(1,))
      (1): LeakyReLU(negative_slope=0.01)
      (2): Conv1d(8, 16, kernel_size=(3,),

In [15]:
# Score the generator outputs from the previous cell with the (untrained)
# critic. Expect one scalar per sample, and the result should be in the
# autograd graph.
critique = critic(rf_settings, formfactor, current_profile, bunch_length)
print(f"{critique = }")
print(f"{critique.requires_grad = }")

assert critique.shape == (rf_settings.shape[0], 1), critique.shape
assert critique.requires_grad


critique = tensor([[-0.0552],
        [-0.0695],
        [-0.0499]], grad_fn=<AddmmBackward0>)
critique.requires_grad = True


In [16]:
# Load the dataset with normalisation enabled. See train_gan for exactly what
# `normalize=True` does to each field.
dataset = EuXFELCurrentDataset(normalize=True)
dataset


<train_gan.EuXFELCurrentDataset at 0x2a9df0490>

In [17]:
# Number of samples. In the run shown this is 25600 = 400 batches x 64, the
# same as the train loader further down, so this is presumably the training
# split (TODO confirm the default split in EuXFELCurrentDataset).
len(dataset)


25600

In [18]:
# Unpack the first sample. Items are ((rf_settings, formfactor),
# (current_profile, bunch_length)), the same grouping as the generator's inputs
# and outputs. NOTE: this rebinds the names used for the random batch in the
# generator/critic cells, so re-running those cells afterwards would pick up
# this single, unbatched sample instead.
(rf_settings, formfactor), (current_profile, bunch_length) = dataset[0]


In [19]:
# Formfactor of the first sample, as returned with normalize=True (saved
# output truncated).
formfactor


tensor([ 0.7084,  0.3383, -1.2126,  0.4945,  0.6809,  0.2575,  0.4530,  0.4780,
         0.5395,  0.5633,  0.7271,  0.3564,  0.5511,  0.8542,  0.6780,  0.3046,
         0.7898,  0.7712,  0.8822,  0.7930,  0.9108,  0.8205,  0.8450,  0.8070,
         0.8214,  0.8884,  0.8506,  0.9235,  0.8703,  0.8487,  0.8226,  0.8884,
         1.0721,  0.5000,  0.9647,  0.9065,  0.8927,  0.9351,  0.9809,  0.9506,
         0.9551,  1.0143,  1.0424,  1.0778,  1.0284,  1.0647,  1.0416,  1.1065,
         1.1268,  1.1140,  1.1473,  1.1615,  1.1791,  1.2142,  1.2281,  1.2646,
         1.3042,  1.3204,  1.3393,  1.3756,  1.3056,  1.3127,  1.2665,  1.3649,
         1.3368,  1.4493,  1.4306,  1.4408,  1.4224,  1.3839,  1.4319,  1.4403,
         1.4972,  1.5371,  1.5012,  1.5207,  1.5666,  1.5677,  1.6014,  1.6086,
         1.6435,  1.6570,  1.6868,  1.7112,  1.7433,  1.7737,  1.7999,  1.8310,
         1.8532,  1.8524,  1.6707,  1.6836,  1.8171,  1.8178,  1.8407,  1.8398,
         1.8740,  1.8806,  1.8759,  1.94

In [20]:
# RF settings of the first sample (5 values), as returned with normalize=True.
rf_settings


tensor([ 0.3616,  1.7155, -1.3814, -0.2878, -0.9114])

In [21]:
# Bunch length of the first sample (a single value), as returned with
# normalize=True.
bunch_length


tensor([-1.0020])

In [22]:
# Current profile of the first sample, as returned with normalize=True (saved
# output truncated).
current_profile


tensor([-0.4532, -0.4376, -0.4221, -0.4059, -0.3925, -0.3846, -0.3829, -0.3856,
        -0.3896, -0.3934, -0.3964, -0.3985, -0.3999, -0.4005, -0.4007, -0.4005,
        -0.4002, -0.4000, -0.3998, -0.3997, -0.3998, -0.4000, -0.4004, -0.4008,
        -0.4013, -0.4015, -0.4019, -0.4037, -0.4074, -0.4126, -0.4183, -0.4243,
        -0.4309, -0.4379, -0.4454, -0.4532, -0.4611, -0.4690, -0.4769, -0.4847,
        -0.4922, -0.4992, -0.5053, -0.5109, -0.5174, -0.5251, -0.5326, -0.5390,
        -0.5442, -0.5495, -0.5561, -0.5640, -0.5720, -0.5795, -0.5863, -0.5925,
        -0.5986, -0.6056, -0.6132, -0.6202, -0.6254, -0.6270, -0.6191, -0.5927,
        -0.5399, -0.4538, -0.3326, -0.1847, -0.0210,  0.1518,  0.3331,  0.5266,
         0.7321,  0.9510,  1.1886,  1.4529,  1.7598,  2.1348,  2.5822,  3.0652,
         3.4982,  3.7927,  3.9313,  3.9726,  4.0022,  4.0659,  4.1547,  4.2426,
         4.3249,  4.3937,  4.4215,  4.3959,  4.3457,  4.2917,  4.2158,  4.0946,
         3.9514,  3.8225,  3.7079,  3.58

In [23]:
# Data module providing train/val/test loaders over normalised data,
# 64 samples per batch.
data_module = EuXFELCurrentDataModule(normalize=True, batch_size=64)
data_module


<train_gan.EuXFELCurrentDataModule at 0x2a9df07c0>

In [24]:
# Prepare the data splits. stage="fit" follows the PyTorch Lightning convention.
# The test loader below also works after only this call, so setup("fit")
# presumably prepares every split (TODO confirm in train_gan).
data_module.setup(stage="fit")


In [25]:
# Training loader: 400 batches of 64 in the run shown. A split-specific name
# keeps this loader from being clobbered when the val/test loaders are created.
train_loader = data_module.train_dataloader()
train_loader, len(train_loader)


(<torch.utils.data.dataloader.DataLoader at 0x2a9df02e0>, 400)

In [26]:
# Validation loader: 50 batches of 64 in the run shown. A split-specific name
# keeps this loader from being clobbered when the test loader is created.
val_loader = data_module.val_dataloader()
val_loader, len(val_loader)


(<torch.utils.data.dataloader.DataLoader at 0x2a9df0b50>, 50)

In [27]:
# Test loader: 50 batches of 64 in the run shown.
# NOTE(review): `data_loader` is a generic name; `test_loader` would be clearer,
# but the name is kept in case cells beyond this point read `data_loader`.
data_loader = data_module.test_dataloader()
data_loader, len(data_loader)


(<torch.utils.data.dataloader.DataLoader at 0x2a9df0520>, 50)