<a href="https://colab.research.google.com/github/byrkbrk/calpagan-experiment/blob/main/calpagan_train.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# clone the repository
!git clone https://github.com/byrkbrk/calpagan-experiment.git

# unzip the dataset
!unzip calpagan-experiment/geant-delphes-train.zip

# relocate the datasets
!mv *-train.pt calpagan-experiment/

# reloacate the bins files 
!mv ./calpagan-experiment/*bins.npy ./

# add directory to path
import sys
sys.path.append("./calpagan-experiment/")

# install pyjet
!pip install pyjet

In [None]:
# import modules
from utils import * 
from models import UNet, Discriminator
from loss_functions import get_disc_loss, get_gen_loss
import torch.optim as optim

# Use the first CUDA GPU when torch reports one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(device)

In [None]:
# data preparation
# Loads the Geant4 and Delphes training tensors, shows the first image of each,
# then scales every sample by the maximum of its Delphes image.

# read data
# NOTE(review): `dir` shadows the Python builtin; the name is kept in case
# other cells read it
dir = "calpagan-experiment/"

# NOTE: torch.load unpickles the file -- only load datasets from a trusted source
geant = torch.load(dir + "geant-train.pt")
print("geant4 shape:", geant.shape)
delphes = torch.load(dir + "delphes-train.pt")
print("delphes shape:", delphes.shape)

# plot the first (not yet normalized) image of each dataset side by side
fig, ax = plt.subplots(1, 2, figsize=(8, 8))
ax[0].imshow(geant[0].squeeze())
ax[0].set_title("geant4")
ax[1].imshow(delphes[0].squeeze())
ax[1].set_title("delphes")
plt.show()

# normalize the datasets
# One maximum per sample, taken over all values of the Delphes image and
# reshaped to (N, 1, 1, 1) so that it broadcasts over the sample.
# assumes both tensors are 4-D with the sample index first -- TODO confirm
each_delphes_max = delphes.reshape(len(delphes), -1).\
                    max(dim=1).values.reshape(len(delphes), 1, 1, 1)
# An all-zero Delphes image has maximum 0; dividing by it would produce NaN/inf
# values. Use 1 as the scale for such samples so they are left unchanged.
each_delphes_max = torch.where(
    each_delphes_max == 0,
    torch.ones_like(each_delphes_max),
    each_delphes_max,
)
# Both datasets are divided by the *Delphes* maximum (in place), so a
# Geant4/Delphes pair shares one scale factor.
geant /= each_delphes_max
delphes /= each_delphes_max


In [None]:
# Build the data pipeline, the two networks and their optimizers.
BATCH_SIZE = 64
LEARNING_RATE = 2e-4

# prepare dataloader: one (geant, delphes, delphes-max) tuple per sample
dataset = [sample for sample in zip(geant, delphes, each_delphes_max)]
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

# instantiate generator -- note the TWO input channels
generator = UNet(2, 1).to(device).apply(weights_init)

# instantiate discriminator
discriminator = Discriminator(1).to(device).apply(weights_init)

# define optimizers (one Adam instance per network, same learning rate)
gen_opt = optim.Adam(generator.parameters(), lr=LEARNING_RATE)
disc_opt = optim.Adam(discriminator.parameters(), lr=LEARNING_RATE)


In [None]:
# train
# Alternates one discriminator update and one generator update per batch and
# reports the losses averaged over every window of `display_step` steps.
n_epochs = 100
display_step = 100
jets_no = [0, 1, 2]
cur_step = 0
# running averages over the current reporting window
disc_loss, gen_loss, recon_loss, adv_loss = 0, 0, 0, 0
loss_dict = {
    "steps": [],
    "disc_losses": [],
    "adv_losses": [],
    "recon_losses": [],
    "gen_losses": [],
}
log_template = "Epoch: {} Steps: {} Disc Loss: {:.3f} Gen Loss: {:.3f} Recon Loss: {:.3f} Adv Loss: {:.3f}"

for epoch in range(n_epochs):
    for batch in tqdm(dataloader):
        # each batch holds (geant image, delphes image, per-sample delphes max)
        real, condition, batch_max_delphes = (tensor.to(device) for tensor in batch)

        # generator input: the condition with uniform noise concatenated along dim 1
        noise = torch.rand_like(condition)
        generator_input = torch.cat((condition, noise), dim=1)
        fake = generator(generator_input)

        # discriminator step
        disc_opt.zero_grad()
        step_disc_loss = get_disc_loss(discriminator, real, fake)
        step_disc_loss.backward()
        disc_opt.step()

        # generator step (evaluated after the discriminator update above)
        gen_opt.zero_grad()
        step_gen_loss, step_recon_loss, step_adv_loss = get_gen_loss(discriminator, real, fake)
        step_gen_loss.backward()
        gen_opt.step()

        # update the running averages; dividing each term by display_step makes
        # the sums equal the window mean once display_step steps have passed
        cur_step += 1
        disc_loss += step_disc_loss.item() / display_step
        gen_loss += step_gen_loss.item() / display_step
        recon_loss += step_recon_loss.item() / display_step
        adv_loss += step_adv_loss.item() / display_step

        # nothing to report until a full window has been accumulated
        if cur_step % display_step != 0:
            continue

        # plot and print some statistics
        print(log_template.format(epoch, cur_step, disc_loss, gen_loss, recon_loss, adv_loss))
        # multiply by the per-sample maximum to undo the normalization before plotting
        show_multiple_images_and_jets(
            condition * batch_max_delphes,
            real * batch_max_delphes,
            fake * batch_max_delphes,
            cur_step, jets_no, 8, 1e-5, True, False)
        _ = show_loss_curve(loss_dict, cur_step, disc_loss, adv_loss)

        # start a new reporting window
        disc_loss, gen_loss, recon_loss, adv_loss = 0, 0, 0, 0

        