In [1]:
import torch

from train_gan import (
    Critic,
    EuXFELCurrentDataModule,
    EuXFELCurrentDataset,
    Generator,
    ConvolutionalDecoder,
    ConvolutionalEncoder,
)


In [2]:
# Stand-alone encoder for the 240-sample formfactor signal; it compresses each
# signal to a 10-dimensional latent vector. The trailing bare expression shows
# the module layout so the layer sizes can be checked by eye.
formfactor_encoder = ConvolutionalEncoder(signal_dims=240, latent_dims=10)
formfactor_encoder


ConvolutionalEncoder(
  (convnet): Sequential(
    (0): Conv1d(1, 8, kernel_size=(3,), stride=(2,), padding=(1,))
    (1): LeakyReLU(negative_slope=0.01)
    (2): Conv1d(8, 16, kernel_size=(3,), stride=(2,), padding=(1,))
    (3): LeakyReLU(negative_slope=0.01)
    (4): Conv1d(16, 32, kernel_size=(3,), stride=(2,), padding=(1,))
    (5): LeakyReLU(negative_slope=0.01)
  )
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (mlp): Sequential(
    (0): Linear(in_features=960, out_features=100, bias=True)
    (1): LeakyReLU(negative_slope=0.01)
    (2): Linear(in_features=100, out_features=50, bias=True)
    (3): LeakyReLU(negative_slope=0.01)
    (4): Linear(in_features=50, out_features=10, bias=True)
  )
)

In [3]:
# Smoke test: run a random batch of 3 formfactors (240 samples each) through
# the encoder, then report the latent vectors and whether they are still
# attached to the autograd graph.
formfactor = torch.rand((3, 240))
encoded = formfactor_encoder(formfactor)
print(f"{encoded = }", f"{encoded.requires_grad = }", sep="\n")


encoded = tensor([[-0.0393,  0.0811,  0.1199,  0.0599, -0.0660,  0.0160, -0.1065, -0.0042,
          0.1152, -0.1126],
        [-0.0406,  0.0854,  0.1207,  0.0606, -0.0653,  0.0146, -0.1049, -0.0053,
          0.1116, -0.1091],
        [-0.0426,  0.0803,  0.1253,  0.0619, -0.0626,  0.0149, -0.1068, -0.0015,
          0.1141, -0.1099]], grad_fn=<AddmmBackward0>)
encoded.requires_grad = True


In [4]:
# Same encoder architecture, sized for the 300-sample current profile; only the
# first linear layer changes width with the signal length (compare the printed
# layouts of the two encoders).
current_encoder = ConvolutionalEncoder(signal_dims=300, latent_dims=10)
current_encoder


ConvolutionalEncoder(
  (convnet): Sequential(
    (0): Conv1d(1, 8, kernel_size=(3,), stride=(2,), padding=(1,))
    (1): LeakyReLU(negative_slope=0.01)
    (2): Conv1d(8, 16, kernel_size=(3,), stride=(2,), padding=(1,))
    (3): LeakyReLU(negative_slope=0.01)
    (4): Conv1d(16, 32, kernel_size=(3,), stride=(2,), padding=(1,))
    (5): LeakyReLU(negative_slope=0.01)
  )
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (mlp): Sequential(
    (0): Linear(in_features=1216, out_features=100, bias=True)
    (1): LeakyReLU(negative_slope=0.01)
    (2): Linear(in_features=100, out_features=50, bias=True)
    (3): LeakyReLU(negative_slope=0.01)
    (4): Linear(in_features=50, out_features=10, bias=True)
  )
)

In [5]:
# Smoke test: encode a random batch of 3 current profiles (300 samples each)
# and report the latent vectors together with their autograd status.
current_profile = torch.rand((3, 300))
encoded = current_encoder(current_profile)
print(f"{encoded = }", f"{encoded.requires_grad = }", sep="\n")


encoded = tensor([[-0.0733,  0.0361,  0.0691, -0.1389, -0.0472, -0.0301,  0.0685,  0.1085,
         -0.0122, -0.0509],
        [-0.0738,  0.0387,  0.0680, -0.1381, -0.0491, -0.0315,  0.0672,  0.1079,
         -0.0128, -0.0521],
        [-0.0738,  0.0409,  0.0709, -0.1325, -0.0459, -0.0315,  0.0656,  0.1074,
         -0.0132, -0.0517]], grad_fn=<AddmmBackward0>)
encoded.requires_grad = True


In [6]:
# Decoder that expands a 16-dimensional latent vector into a 300-sample signal.
# NOTE(review): 10 + 5 + 1 presumably stands for 10 encoded formfactor features
# + 5 RF settings + 1 bunch length (matching the generator inputs) - confirm.
current_decoder = ConvolutionalDecoder(latent_dims=10 + 5 + 1, signal_dims=300)
current_decoder


ConvolutionalDecoder(
  (mlp): Sequential(
    (0): Linear(in_features=16, out_features=50, bias=True)
    (1): LeakyReLU(negative_slope=0.01)
    (2): Linear(in_features=50, out_features=100, bias=True)
    (3): LeakyReLU(negative_slope=0.01)
    (4): Linear(in_features=100, out_features=1216, bias=True)
    (5): LeakyReLU(negative_slope=0.01)
  )
  (unflatten): Unflatten(dim=1, unflattened_size=(32, 38))
  (convnet): Sequential(
    (0): ConvTranspose1d(32, 16, kernel_size=(3,), stride=(2,), padding=(1,))
    (1): LeakyReLU(negative_slope=0.01)
    (2): ConvTranspose1d(16, 8, kernel_size=(3,), stride=(2,), padding=(1,), output_padding=(1,))
    (3): LeakyReLU(negative_slope=0.01)
    (4): ConvTranspose1d(8, 1, kernel_size=(3,), stride=(2,), padding=(1,), output_padding=(1,))
    (5): ReLU()
  )
)

In [7]:
# Spot-check that the decoder is trainable by inspecting its first parameter.
# Use the built-in next() rather than calling the iterator's __next__ dunder
# directly; the displayed value is the same.
next(current_decoder.parameters()).requires_grad


True

In [8]:
# Smoke test: decode a random batch of 3 latent vectors (16 values each) and
# report the decoded signals together with their autograd status.
latent = torch.rand((3, 10 + 5 + 1))
decoded = current_decoder(latent)
print(f"{decoded = }", f"{decoded.requires_grad = }", sep="\n")


decoded = tensor([[0.0355, 0.1487, 0.0441, 0.1156, 0.0283, 0.1418, 0.0135, 0.1049, 0.0328,
         0.1448, 0.0355, 0.1072, 0.0214, 0.1316, 0.0104, 0.1116, 0.0271, 0.1539,
         0.0343, 0.1019, 0.0249, 0.1396, 0.0261, 0.1059, 0.0251, 0.1328, 0.0339,
         0.1187, 0.0279, 0.1428, 0.0208, 0.1125, 0.0299, 0.1541, 0.0083, 0.0921,
         0.0286, 0.1398, 0.0042, 0.1082, 0.0246, 0.1470, 0.0522, 0.1215, 0.0211,
         0.1342, 0.0000, 0.1140, 0.0333, 0.1581, 0.0306, 0.1058, 0.0286, 0.1353,
         0.0079, 0.1156, 0.0271, 0.1461, 0.0283, 0.1103, 0.0328, 0.1359, 0.0195,
         0.1135, 0.0254, 0.1417, 0.0515, 0.1285, 0.0327, 0.1401, 0.0236, 0.1130,
         0.0248, 0.1394, 0.0334, 0.1194, 0.0299, 0.1431, 0.0330, 0.1147, 0.0286,
         0.1422, 0.0307, 0.1072, 0.0259, 0.1374, 0.0215, 0.1096, 0.0291, 0.1372,
         0.0377, 0.1099, 0.0218, 0.1382, 0.0221, 0.1093, 0.0275, 0.1425, 0.0095,
         0.0959, 0.0277, 0.1412, 0.0211, 0.1149, 0.0246, 0.1420, 0.0366, 0.1146,
         0.0313, 0

In [9]:
# Decoder counterpart for the formfactor: expands a 16-dimensional latent
# vector into a 240-sample signal.
# NOTE(review): 10 + 5 + 1 presumably mirrors the generator's latent layout
# (encoded formfactor + RF settings + bunch length) - confirm.
formfactor_decoder = ConvolutionalDecoder(latent_dims=10 + 5 + 1, signal_dims=240)
formfactor_decoder


ConvolutionalDecoder(
  (mlp): Sequential(
    (0): Linear(in_features=16, out_features=50, bias=True)
    (1): LeakyReLU(negative_slope=0.01)
    (2): Linear(in_features=50, out_features=100, bias=True)
    (3): LeakyReLU(negative_slope=0.01)
    (4): Linear(in_features=100, out_features=960, bias=True)
    (5): LeakyReLU(negative_slope=0.01)
  )
  (unflatten): Unflatten(dim=1, unflattened_size=(32, 30))
  (convnet): Sequential(
    (0): ConvTranspose1d(32, 16, kernel_size=(3,), stride=(2,), padding=(1,), output_padding=(1,))
    (1): LeakyReLU(negative_slope=0.01)
    (2): ConvTranspose1d(16, 8, kernel_size=(3,), stride=(2,), padding=(1,), output_padding=(1,))
    (3): LeakyReLU(negative_slope=0.01)
    (4): ConvTranspose1d(8, 1, kernel_size=(3,), stride=(2,), padding=(1,), output_padding=(1,))
    (5): ReLU()
  )
)

In [10]:
# Smoke test: decode a fresh random batch of 3 latent vectors with the
# formfactor decoder and report the result plus its autograd status.
latent = torch.rand((3, 10 + 5 + 1))
decoded = formfactor_decoder(latent)
print(f"{decoded = }", f"{decoded.requires_grad = }", sep="\n")


decoded = tensor([[0.3804, 0.1999, 0.2930, 0.2704, 0.3945, 0.2424, 0.3304, 0.2652, 0.3852,
         0.2451, 0.3454, 0.2422, 0.3703, 0.1961, 0.3121, 0.2568, 0.3901, 0.2057,
         0.3108, 0.2345, 0.3556, 0.2152, 0.3359, 0.2557, 0.3784, 0.2205, 0.3180,
         0.2689, 0.3981, 0.2214, 0.3174, 0.2381, 0.3766, 0.2258, 0.3342, 0.2518,
         0.3649, 0.2187, 0.3130, 0.2421, 0.3708, 0.2008, 0.3244, 0.2476, 0.3769,
         0.2112, 0.3086, 0.2489, 0.3793, 0.2085, 0.3416, 0.2453, 0.3693, 0.2167,
         0.2994, 0.2462, 0.3780, 0.2090, 0.2972, 0.2601, 0.3803, 0.2080, 0.3357,
         0.2604, 0.3832, 0.2199, 0.3416, 0.2695, 0.3751, 0.2250, 0.3093, 0.2344,
         0.3715, 0.2310, 0.3437, 0.2533, 0.3693, 0.2076, 0.2977, 0.2413, 0.3652,
         0.1841, 0.3266, 0.2417, 0.3718, 0.2159, 0.3322, 0.2646, 0.3925, 0.2159,
         0.3040, 0.2416, 0.3752, 0.1862, 0.3167, 0.2422, 0.3774, 0.2132, 0.3261,
         0.2603, 0.3885, 0.2311, 0.3176, 0.2304, 0.3747, 0.1726, 0.3208, 0.2621,
         0.3688, 0

In [11]:
# Build the GAN generator with its default arguments. The printed layout shows
# it is composed of a formfactor encoder and a current decoder.
generator = Generator()
generator


Generator(
  (formfactor_encoder): ConvolutionalEncoder(
    (convnet): Sequential(
      (0): Conv1d(1, 8, kernel_size=(3,), stride=(2,), padding=(1,))
      (1): LeakyReLU(negative_slope=0.01)
      (2): Conv1d(8, 16, kernel_size=(3,), stride=(2,), padding=(1,))
      (3): LeakyReLU(negative_slope=0.01)
      (4): Conv1d(16, 32, kernel_size=(3,), stride=(2,), padding=(1,))
      (5): LeakyReLU(negative_slope=0.01)
    )
    (flatten): Flatten(start_dim=1, end_dim=-1)
    (mlp): Sequential(
      (0): Linear(in_features=960, out_features=100, bias=True)
      (1): LeakyReLU(negative_slope=0.01)
      (2): Linear(in_features=100, out_features=50, bias=True)
      (3): LeakyReLU(negative_slope=0.01)
      (4): Linear(in_features=50, out_features=10, bias=True)
    )
  )
  (current_decoder): ConvolutionalDecoder(
    (mlp): Sequential(
      (0): Linear(in_features=16, out_features=50, bias=True)
      (1): LeakyReLU(negative_slope=0.01)
      (2): Linear(in_features=50, out_features=100

In [12]:
# Spot-check that the generator is trainable by inspecting its first parameter.
# Use the built-in next() rather than calling the iterator's __next__ dunder
# directly; the displayed value is the same.
next(generator.parameters()).requires_grad


True

In [13]:
# Synthetic generator inputs for a batch of 3: a 240-sample formfactor,
# 5 RF settings and a single bunch-length value per sample.
formfactor = torch.rand((3, 240))
rf_settings = torch.rand((3, 5))
bunch_length = torch.rand((3, 1))

# Forward pass; only the output size and autograd status are reported, since
# the values themselves come from an untrained network.
current_profile = generator(formfactor, rf_settings, bunch_length)
print(
    f"{current_profile.size() = }",
    f"{current_profile.requires_grad = }",
    sep="\n",
)


current_profile.size() = torch.Size([3, 300])
current_profile.requires_grad = True


In [14]:
# Build the GAN critic with its default arguments. The printed layout shows it
# contains its own formfactor encoder as well as a current-profile encoder.
critic = Critic()
critic


Critic(
  (formfactor_encoder): ConvolutionalEncoder(
    (convnet): Sequential(
      (0): Conv1d(1, 8, kernel_size=(3,), stride=(2,), padding=(1,))
      (1): LeakyReLU(negative_slope=0.01)
      (2): Conv1d(8, 16, kernel_size=(3,), stride=(2,), padding=(1,))
      (3): LeakyReLU(negative_slope=0.01)
      (4): Conv1d(16, 32, kernel_size=(3,), stride=(2,), padding=(1,))
      (5): LeakyReLU(negative_slope=0.01)
    )
    (flatten): Flatten(start_dim=1, end_dim=-1)
    (mlp): Sequential(
      (0): Linear(in_features=960, out_features=100, bias=True)
      (1): LeakyReLU(negative_slope=0.01)
      (2): Linear(in_features=100, out_features=50, bias=True)
      (3): LeakyReLU(negative_slope=0.01)
      (4): Linear(in_features=50, out_features=10, bias=True)
    )
  )
  (current_encoder): ConvolutionalEncoder(
    (convnet): Sequential(
      (0): Conv1d(1, 8, kernel_size=(3,), stride=(2,), padding=(1,))
      (1): LeakyReLU(negative_slope=0.01)
      (2): Conv1d(8, 16, kernel_size=(3,),

In [15]:
# Score the generated current profiles, conditioned on the same inputs that
# produced them, and report the scores plus their autograd status.
critique = critic(current_profile, formfactor, rf_settings, bunch_length)
print(f"{critique = }", f"{critique.requires_grad = }", sep="\n")


critique = tensor([[-0.0183],
        [ 0.0024],
        [-0.0005]], grad_fn=<AddmmBackward0>)
critique.requires_grad = True


In [16]:
# Instantiate the dataset with its default arguments. The bare expression only
# shows the default object repr; size and contents are inspected in later cells.
dataset = EuXFELCurrentDataset()
dataset


<train_gan.EuXFELCurrentDataset at 0x296c970d0>

In [23]:
# Total number of samples in the dataset.
# NOTE(review): the execution counter jumps from 16 to 23 here - restart the
# kernel and run all cells to confirm the notebook works top to bottom.
len(dataset)


32000

In [24]:
# A dataset item unpacks as ((formfactor, rf_settings, bunch_length), current_profile).
# Note: this rebinds the names that held the random smoke-test tensors in the
# generator cell, so from here on they refer to one real (unbatched) sample.
(formfactor, rf_settings, bunch_length), current_profile = dataset[0]


In [25]:
# Inspect the formfactor of the first dataset sample.
formfactor


tensor([ 1.1136,  0.9786,  1.0596,  0.9905,  0.9889,  0.8491,  0.9743,  0.9680,
         0.9987,  0.9678,  0.9987,  0.8583,  0.9538,  0.9625,  0.9760,  1.1724,
         0.9697,  0.9435,  0.9241,  0.9130,  0.9376,  0.9233,  0.8949,  0.8922,
         0.9359,  0.9103,  0.9332,  0.9133,  0.8929,  0.9105,  0.8257,  0.8823,
         0.8891,  0.8212,  0.8671,  0.8750,  0.8668,  0.8590,  0.8542,  0.8362,
         0.8528,  0.8493,  0.8486,  0.8386,  0.8386,  0.8227,  0.8215,  0.8296,
         0.8120,  0.8044,  0.7988,  0.7836,  0.7764,  0.7696,  0.7519,  0.7426,
         0.7177,  0.6808,  0.6909,  0.6645,  0.6718,  0.5611,  0.6621,  0.5888,
         0.6114,  0.6397,  0.6155,  0.6192,  0.5929,  0.6077,  0.5308,  0.5788,
         0.5683,  0.5319,  0.5339,  0.5141,  0.5032,  0.4987,  0.4759,  0.4483,
         0.4318,  0.4187,  0.3994,  0.3706,  0.3445,  0.3279,  0.3064,  0.2693,
         0.2370,  0.2587, -0.2030,  0.2165,  0.3358,  0.1943,  0.3144,  0.1282,
         0.1266,  0.1986,  0.1814,  0.15

In [26]:
# Inspect the 5 RF settings of the first sample.
# NOTE(review): the displayed values span several orders of magnitude - confirm
# whether they are normalised elsewhere before being fed to the networks.
rf_settings


tensor([-7.4400e+00,  2.8130e+02,  4.5442e+04, -1.3430e+01,  2.1100e+00])

In [27]:
# Inspect the bunch length of the first sample (a single-element tensor).
bunch_length


tensor([9.0783e-05])

In [28]:
# Inspect the target current profile of the first sample.
current_profile


tensor([ 179.1833,  265.1158,  370.1751,  491.9395,  626.4266,  768.1013,
         908.7740, 1042.1052, 1162.9528, 1264.8722, 1348.4093, 1413.8678,
        1462.4076, 1493.7413, 1512.1744, 1519.9905, 1519.4473, 1512.3607,
        1502.1366, 1490.9547, 1480.4468, 1471.3280, 1463.6165, 1456.7324,
        1449.9124, 1442.8661, 1435.5945, 1428.8264, 1423.0869, 1419.1271,
        1416.9052, 1415.7997, 1415.0654, 1413.8822, 1411.8910, 1409.0577,
        1405.5623, 1401.7904, 1397.8761, 1393.8970, 1390.1157, 1387.0282,
        1384.8540, 1383.8727, 1383.8990, 1384.2417, 1383.5660, 1380.6753,
        1374.9377, 1366.2657, 1356.0144, 1346.1371, 1338.5245, 1334.9456,
        1335.8098, 1340.0797, 1346.2191, 1352.3209, 1356.5610, 1357.8591,
        1355.5807, 1349.4393, 1340.0415, 1328.5688, 1316.5232, 1305.4174,
        1296.7922, 1291.0576, 1288.1896, 1287.8152, 1289.4836, 1292.4856,
        1296.4257, 1300.7830, 1304.6067, 1307.0728, 1307.5531, 1305.6469,
        1301.3713, 1295.5537, 1288.794

In [29]:
# Data module that serves the dataset in batches of 64.
data_module = EuXFELCurrentDataModule(batch_size=64)
data_module


<train_gan.EuXFELCurrentDataModule at 0x296c97580>

In [30]:
# Prepare the data module before requesting a dataloader from it.
# NOTE(review): if this is a LightningDataModule, the Trainer passes
# stage="fit" (not "train") - confirm that setup() handles both the same way.
data_module.setup(stage="train")


In [31]:
# Obtain the training dataloader from the prepared data module.
data_loader = data_module.train_dataloader()
data_loader


<torch.utils.data.dataloader.DataLoader at 0x296c97670>

In [32]:
# Number of batches per training epoch. With batch_size=64 this is about
# 19,200 samples, i.e. roughly 60% of the 32,000-sample dataset - presumably
# the training split (TODO confirm against the data module).
len(data_loader)


300