In [1]:
import torch

from train_gan import (
    Critic,
    EuXFELCurrentDataModule,
    EuXFELCurrentDataset,
    Generator,
    SignalDecoder,
    SignalEncoder,
)


In [2]:
formfactor_encoder = SignalEncoder(signal_dims=240, latent_dims=10)
formfactor_encoder


SignalEncoder(
  (convnet): Sequential(
    (0): Conv1d(1, 8, kernel_size=(3,), stride=(2,), padding=(1,))
    (1): LeakyReLU(negative_slope=0.01)
    (2): Conv1d(8, 16, kernel_size=(3,), stride=(2,), padding=(1,))
    (3): LeakyReLU(negative_slope=0.01)
    (4): Conv1d(16, 32, kernel_size=(3,), stride=(2,), padding=(1,))
    (5): LeakyReLU(negative_slope=0.01)
  )
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (mlp): Sequential(
    (0): Linear(in_features=960, out_features=100, bias=True)
    (1): LeakyReLU(negative_slope=0.01)
    (2): Linear(in_features=100, out_features=50, bias=True)
    (3): LeakyReLU(negative_slope=0.01)
    (4): Linear(in_features=50, out_features=10, bias=True)
  )
)

In [3]:
formfactor = torch.rand(3, 240)
encoded = formfactor_encoder(formfactor)
print(f"{encoded = }")
print(f"{encoded.requires_grad = }")


ENCODER signal.size() = torch.Size([3, 240]) signal.requires_grad = False
ENCODER after torch.unsqueeze(signal, dim=1) x.size() = torch.Size([3, 1, 240]) x.requires_grad = False
ENCODER EXPECT GRAD NEXT
ENCODER after self.convnet(x) x.size() = torch.Size([3, 32, 30]) x.requires_grad = True
ENCODER after self.flatten(x) x.size() = torch.Size([3, 960]) x.requires_grad = True
ENCODER encoded.size() = torch.Size([3, 10]) encoded.requires_grad = True
encoded = tensor([[ 0.1248,  0.0462, -0.0246,  0.0263,  0.0782, -0.0171,  0.0932, -0.1210,
          0.0595,  0.0265],
        [ 0.1248,  0.0478, -0.0253,  0.0280,  0.0787, -0.0158,  0.0945, -0.1197,
          0.0577,  0.0270],
        [ 0.1218,  0.0465, -0.0240,  0.0256,  0.0807, -0.0161,  0.0943, -0.1180,
          0.0617,  0.0255]], grad_fn=<AddmmBackward0>)
encoded.requires_grad = True


In [4]:
current_encoder = SignalEncoder(signal_dims=300, latent_dims=10)
current_encoder


SignalEncoder(
  (convnet): Sequential(
    (0): Conv1d(1, 8, kernel_size=(3,), stride=(2,), padding=(1,))
    (1): LeakyReLU(negative_slope=0.01)
    (2): Conv1d(8, 16, kernel_size=(3,), stride=(2,), padding=(1,))
    (3): LeakyReLU(negative_slope=0.01)
    (4): Conv1d(16, 32, kernel_size=(3,), stride=(2,), padding=(1,))
    (5): LeakyReLU(negative_slope=0.01)
  )
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (mlp): Sequential(
    (0): Linear(in_features=1216, out_features=100, bias=True)
    (1): LeakyReLU(negative_slope=0.01)
    (2): Linear(in_features=100, out_features=50, bias=True)
    (3): LeakyReLU(negative_slope=0.01)
    (4): Linear(in_features=50, out_features=10, bias=True)
  )
)

In [5]:
current_profile = torch.rand(3, 300)
encoded = current_encoder(current_profile)
print(f"{encoded = }")
print(f"{encoded.requires_grad = }")


ENCODER signal.size() = torch.Size([3, 300]) signal.requires_grad = False
ENCODER after torch.unsqueeze(signal, dim=1) x.size() = torch.Size([3, 1, 300]) x.requires_grad = False
ENCODER EXPECT GRAD NEXT
ENCODER after self.convnet(x) x.size() = torch.Size([3, 32, 38]) x.requires_grad = True
ENCODER after self.flatten(x) x.size() = torch.Size([3, 1216]) x.requires_grad = True
ENCODER encoded.size() = torch.Size([3, 10]) encoded.requires_grad = True
encoded = tensor([[-0.0398, -0.0358,  0.0460,  0.0621,  0.1480,  0.0769, -0.0201,  0.1396,
         -0.0206,  0.0531],
        [-0.0433, -0.0331,  0.0438,  0.0589,  0.1445,  0.0784, -0.0195,  0.1412,
         -0.0218,  0.0506],
        [-0.0435, -0.0337,  0.0459,  0.0602,  0.1449,  0.0767, -0.0193,  0.1397,
         -0.0244,  0.0504]], grad_fn=<AddmmBackward0>)
encoded.requires_grad = True


In [6]:
current_decoder = SignalDecoder(latent_dims=10 + 5 + 1, signal_dims=300)
current_decoder


SignalDecoder(
  (mlp): Sequential(
    (0): Linear(in_features=16, out_features=50, bias=True)
    (1): LeakyReLU(negative_slope=0.01)
    (2): Linear(in_features=50, out_features=100, bias=True)
    (3): LeakyReLU(negative_slope=0.01)
    (4): Linear(in_features=100, out_features=1216, bias=True)
    (5): LeakyReLU(negative_slope=0.01)
  )
  (unflatten): Unflatten(dim=1, unflattened_size=(32, 38))
  (convnet): Sequential(
    (0): ConvTranspose1d(32, 16, kernel_size=(3,), stride=(2,), padding=(1,))
    (1): LeakyReLU(negative_slope=0.01)
    (2): ConvTranspose1d(16, 8, kernel_size=(3,), stride=(2,), padding=(1,), output_padding=(1,))
    (3): LeakyReLU(negative_slope=0.01)
    (4): ConvTranspose1d(8, 1, kernel_size=(3,), stride=(2,), padding=(1,), output_padding=(1,))
    (5): ReLU()
  )
)

In [7]:
current_decoder.parameters().__next__().requires_grad

True

In [8]:
latent = torch.rand(3, 10 + 5 + 1)
decoded = current_decoder(latent)
print(f"{decoded = }")
print(f"{decoded.requires_grad = }")


SIGNAL DECODER encoded.size() = torch.Size([3, 16]) encoded.requires_grad = False
SIGNAL DECODER EXPECT GRAD NEXT
SIGNAL DECODER x.size() = torch.Size([3, 1216]) x.requires_grad = True
SIGNAL DECODER x.size() = torch.Size([3, 32, 38]) x.requires_grad = True
SIGNAL DECODER x.size() = torch.Size([3, 1, 300]) x.requires_grad = True
SIGNAL DECODER signal.size() = torch.Size([3, 300]) signal.requires_grad = True
decoded = tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0

In [9]:
formfactor_decoder = SignalDecoder(latent_dims=10 + 5 + 1, signal_dims=240)
formfactor_decoder


SignalDecoder(
  (mlp): Sequential(
    (0): Linear(in_features=16, out_features=50, bias=True)
    (1): LeakyReLU(negative_slope=0.01)
    (2): Linear(in_features=50, out_features=100, bias=True)
    (3): LeakyReLU(negative_slope=0.01)
    (4): Linear(in_features=100, out_features=960, bias=True)
    (5): LeakyReLU(negative_slope=0.01)
  )
  (unflatten): Unflatten(dim=1, unflattened_size=(32, 30))
  (convnet): Sequential(
    (0): ConvTranspose1d(32, 16, kernel_size=(3,), stride=(2,), padding=(1,), output_padding=(1,))
    (1): LeakyReLU(negative_slope=0.01)
    (2): ConvTranspose1d(16, 8, kernel_size=(3,), stride=(2,), padding=(1,), output_padding=(1,))
    (3): LeakyReLU(negative_slope=0.01)
    (4): ConvTranspose1d(8, 1, kernel_size=(3,), stride=(2,), padding=(1,), output_padding=(1,))
    (5): ReLU()
  )
)

In [10]:
latent = torch.rand(3, 10 + 5 + 1)
decoded = formfactor_decoder(latent)
print(f"{decoded = }")
print(f"{decoded.requires_grad = }")


SIGNAL DECODER encoded.size() = torch.Size([3, 16]) encoded.requires_grad = False
SIGNAL DECODER EXPECT GRAD NEXT
SIGNAL DECODER x.size() = torch.Size([3, 960]) x.requires_grad = True
SIGNAL DECODER x.size() = torch.Size([3, 32, 30]) x.requires_grad = True
SIGNAL DECODER x.size() = torch.Size([3, 1, 240]) x.requires_grad = True
SIGNAL DECODER signal.size() = torch.Size([3, 240]) signal.requires_grad = True
decoded = tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.

In [11]:
generator = Generator()
generator


Generator(
  (formfactor_encoder): SignalEncoder(
    (convnet): Sequential(
      (0): Conv1d(1, 8, kernel_size=(3,), stride=(2,), padding=(1,))
      (1): LeakyReLU(negative_slope=0.01)
      (2): Conv1d(8, 16, kernel_size=(3,), stride=(2,), padding=(1,))
      (3): LeakyReLU(negative_slope=0.01)
      (4): Conv1d(16, 32, kernel_size=(3,), stride=(2,), padding=(1,))
      (5): LeakyReLU(negative_slope=0.01)
    )
    (flatten): Flatten(start_dim=1, end_dim=-1)
    (mlp): Sequential(
      (0): Linear(in_features=960, out_features=100, bias=True)
      (1): LeakyReLU(negative_slope=0.01)
      (2): Linear(in_features=100, out_features=50, bias=True)
      (3): LeakyReLU(negative_slope=0.01)
      (4): Linear(in_features=50, out_features=10, bias=True)
    )
  )
  (current_decoder): SignalDecoder(
    (mlp): Sequential(
      (0): Linear(in_features=16, out_features=50, bias=True)
      (1): LeakyReLU(negative_slope=0.01)
      (2): Linear(in_features=50, out_features=100, bias=True)
 

In [12]:
generator.parameters().__next__().requires_grad

True

In [14]:
formfactor = torch.rand(3, 240)
rf_settings = torch.rand(3, 5)
bunch_length = torch.rand(3, 1)

current_profile = generator(formfactor, rf_settings, bunch_length)
print(f"{current_profile.size() = }")
print(f"{current_profile.requires_grad = }")


GENERATOR formfactor.size() = torch.Size([3, 240]) formfactor.requires_grad = False
GENERATOR rf_settings.size() = torch.Size([3, 5]) rf_settings.requires_grad = False
GENERATOR bunch_length.size() = torch.Size([3, 1]) bunch_length.requires_grad = False
GENERATOR‚ EXPECT GRAD NEXT
ENCODER signal.size() = torch.Size([3, 240]) signal.requires_grad = False
ENCODER after torch.unsqueeze(signal, dim=1) x.size() = torch.Size([3, 1, 240]) x.requires_grad = False
ENCODER EXPECT GRAD NEXT
ENCODER after self.convnet(x) x.size() = torch.Size([3, 32, 30]) x.requires_grad = True
ENCODER after self.flatten(x) x.size() = torch.Size([3, 960]) x.requires_grad = True
ENCODER encoded.size() = torch.Size([3, 10]) encoded.requires_grad = True
GENERATOR encoded_formfactor.size() = torch.Size([3, 10]) encoded_formfactor.requires_grad = True
GENERATOR latent.size() = torch.Size([3, 16]) latent.requires_grad = True
SIGNAL DECODER encoded.size() = torch.Size([3, 16]) encoded.requires_grad = True
SIGNAL DECODER 

In [15]:
critic = Critic()
critic


Critic(
  (formfactor_encoder): SignalEncoder(
    (convnet): Sequential(
      (0): Conv1d(1, 8, kernel_size=(3,), stride=(2,), padding=(1,))
      (1): LeakyReLU(negative_slope=0.01)
      (2): Conv1d(8, 16, kernel_size=(3,), stride=(2,), padding=(1,))
      (3): LeakyReLU(negative_slope=0.01)
      (4): Conv1d(16, 32, kernel_size=(3,), stride=(2,), padding=(1,))
      (5): LeakyReLU(negative_slope=0.01)
    )
    (flatten): Flatten(start_dim=1, end_dim=-1)
    (mlp): Sequential(
      (0): Linear(in_features=960, out_features=100, bias=True)
      (1): LeakyReLU(negative_slope=0.01)
      (2): Linear(in_features=100, out_features=50, bias=True)
      (3): LeakyReLU(negative_slope=0.01)
      (4): Linear(in_features=50, out_features=10, bias=True)
    )
  )
  (current_encoder): SignalEncoder(
    (convnet): Sequential(
      (0): Conv1d(1, 8, kernel_size=(3,), stride=(2,), padding=(1,))
      (1): LeakyReLU(negative_slope=0.01)
      (2): Conv1d(8, 16, kernel_size=(3,), stride=(2,), 

In [16]:
critique = critic(current_profile, formfactor, rf_settings, bunch_length)
print(f"{critique = }")
print(f"{critique.requires_grad = }")


ENCODER signal.size() = torch.Size([3, 300]) signal.requires_grad = True
ENCODER after torch.unsqueeze(signal, dim=1) x.size() = torch.Size([3, 1, 300]) x.requires_grad = True
ENCODER EXPECT GRAD NEXT
ENCODER after self.convnet(x) x.size() = torch.Size([3, 32, 38]) x.requires_grad = True
ENCODER after self.flatten(x) x.size() = torch.Size([3, 1216]) x.requires_grad = True
ENCODER encoded.size() = torch.Size([3, 10]) encoded.requires_grad = True
ENCODER signal.size() = torch.Size([3, 240]) signal.requires_grad = False
ENCODER after torch.unsqueeze(signal, dim=1) x.size() = torch.Size([3, 1, 240]) x.requires_grad = False
ENCODER EXPECT GRAD NEXT
ENCODER after self.convnet(x) x.size() = torch.Size([3, 32, 30]) x.requires_grad = True
ENCODER after self.flatten(x) x.size() = torch.Size([3, 960]) x.requires_grad = True
ENCODER encoded.size() = torch.Size([3, 10]) encoded.requires_grad = True
CRITIC (out) x.requires_grad = True
critique = tensor([[0.0519],
        [0.0489],
        [0.0452]],

In [14]:
dataset = EuXFELCurrentDataset()
dataset


<train_gan.EuXFELCurrentDataset at 0x15ff700a0>

In [15]:
(formfactor, rf_settings, bunch_length), current_profile = dataset[0]


In [16]:
formfactor


tensor([ 1.0044,  0.9630,  1.0036,  0.8943,  0.9871,  1.2385,  1.0386,  1.0309,
         0.9027,  1.0560,  0.9660,  0.9919,  0.9544,  1.0035,  0.9802,  0.9358,
         0.9007,  0.9520,  0.9279,  0.9389,  0.9560,  0.9737,  0.9385,  0.9257,
         0.9457,  0.9194,  0.9217,  0.9189,  0.8836,  0.9069,  0.8627,  0.9422,
         0.8833,  0.8639,  0.9203,  0.8686,  0.9047,  0.8307,  0.8406,  0.8705,
         0.8725,  0.8646,  0.8225,  0.8547,  0.8434,  0.8378,  0.8149,  0.8217,
         0.8124,  0.7939,  0.7970,  0.7831,  0.7742,  0.7711,  0.7550,  0.7392,
         0.7150,  0.7093,  0.6877,  0.6594,  0.7884,  0.6811,  0.6007,  0.5821,
         0.6001,  0.6237,  0.5887,  0.5902,  0.6090,  0.5914,  0.5975,  0.5878,
         0.5588,  0.5395,  0.5357,  0.5267,  0.5074,  0.4935,  0.4823,  0.4565,
         0.4283,  0.4146,  0.4020,  0.3654,  0.3516,  0.3306,  0.2884,  0.2757,
         0.2254,  0.2019,  0.2199,  0.3086,  0.2065,  0.2649,  0.2790,  0.2603,
         0.1840,  0.1837,  0.1600,  0.17

In [17]:
rf_settings


tensor([-7.4400e+00,  2.8130e+02,  4.5442e+04, -1.3430e+01,  2.1100e+00])

In [18]:
bunch_length


tensor(9.0783e-05)

In [19]:
current_profile


tensor([ 179.1833,  265.1158,  370.1751,  491.9395,  626.4266,  768.1013,
         908.7740, 1042.1052, 1162.9528, 1264.8722, 1348.4093, 1413.8678,
        1462.4076, 1493.7413, 1512.1744, 1519.9905, 1519.4473, 1512.3607,
        1502.1366, 1490.9547, 1480.4468, 1471.3280, 1463.6165, 1456.7324,
        1449.9124, 1442.8661, 1435.5945, 1428.8264, 1423.0869, 1419.1271,
        1416.9052, 1415.7997, 1415.0654, 1413.8822, 1411.8910, 1409.0577,
        1405.5623, 1401.7904, 1397.8761, 1393.8970, 1390.1157, 1387.0282,
        1384.8540, 1383.8727, 1383.8990, 1384.2417, 1383.5660, 1380.6753,
        1374.9377, 1366.2657, 1356.0144, 1346.1371, 1338.5245, 1334.9456,
        1335.8098, 1340.0797, 1346.2191, 1352.3209, 1356.5610, 1357.8591,
        1355.5807, 1349.4393, 1340.0415, 1328.5688, 1316.5232, 1305.4174,
        1296.7922, 1291.0576, 1288.1896, 1287.8152, 1289.4836, 1292.4856,
        1296.4257, 1300.7830, 1304.6067, 1307.0728, 1307.5531, 1305.6469,
        1301.3713, 1295.5537, 1288.794

In [20]:
data_module = EuXFELCurrentDataModule(batch_size=64)
data_module

<train_gan.EuXFELCurrentDataModule at 0x15ff700d0>

In [21]:
data_module.setup(stage="train")

In [22]:
data_loader = data_module.train_dataloader()
data_loader

<torch.utils.data.dataloader.DataLoader at 0x15ff705e0>

In [23]:
len(data_loader)

300