In [None]:
from utils import * 
from models import UNet, Discriminator
from loss_functions import get_disc_loss, get_gen_loss
import torch.optim as optim

# Run on the first CUDA GPU when PyTorch can see one, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(device)

# Helper pulled in by `from utils import *`; presumably prepares the output
# folders used for checkpoints/figures — TODO confirm in utils.
create_directories()

### Data Preparation


In [None]:
# Load data
geant_path = "geant_72x72.pt"
geant = torch.load(geant_path)
print(geant.shape)

delphes_path = "delphes_72x72.pt"
delphes = torch.load(delphes_path)
print(delphes.shape)

# Everything below (normalisation, split, dataloader) indexes the two tensors
# in lockstep, so they must contain the same number of examples.
assert len(geant) == len(delphes), (
    f"geant has {len(geant)} examples but delphes has {len(delphes)}"
)

# Get max "pixel" value (over the whole, still un-normalised, tensors)
max_geant = geant.max()
max_delphes = delphes.max()
print("geant max:", max_geant.item())
print("delphes max:", max_delphes.item())

# Normalize the data.
# Each example is divided by the maximum of *its Delphes image*; the same
# per-example factor is applied to the Geant image so both share one scale.
# The factor is shaped (N, 1, 1, 1) so it broadcasts over the remaining axes.
max_each_delphes = (
    delphes.reshape(len(delphes), -1).max(dim=1).values.reshape(len(delphes), 1, 1, 1)
)
# An all-zero Delphes image has max == 0, which would turn that example into
# NaN/inf in both tensors. Use a factor of 1 there so the example is left
# unchanged (and multiplying by the stored factor later is still a no-op).
max_each_delphes = torch.where(
    max_each_delphes == 0,
    torch.ones_like(max_each_delphes),
    max_each_delphes,
)
geant /= max_each_delphes
delphes /= max_each_delphes

# Separate into train and test sets
n_train_examples = 8500
# Slicing past the end would silently yield an empty test set.
assert 0 < n_train_examples < len(geant), (
    f"n_train_examples={n_train_examples} leaves no test data for {len(geant)} examples"
)
test_geant, test_delphes = geant[n_train_examples:], delphes[n_train_examples:] # test data
geant, delphes = geant[:n_train_examples], delphes[:n_train_examples] # training data
# Keep the normalisation factors split the same way so each set can be rescaled
# back to the original units.
test_max_each_delphes = max_each_delphes[n_train_examples:]
train_max_each_delphes = max_each_delphes[:n_train_examples]
print("Training shape:", geant.shape, delphes.shape)
print("Test shape:", test_geant.shape, test_delphes.shape)

# Save data information (split size and per-example normalisation factors)
data_info_dict = {
    "n_train_examples": n_train_examples,
    "train_max_each_delphes": train_max_each_delphes,
    "test_max_each_delphes": test_max_each_delphes
}
with open("data_info_dict.pickle", "wb") as f:
    pickle.dump(data_info_dict, f)

### Initializations

In [12]:
# Hyperparameters
n_epochs = 100
batch_size = 64
lr = 2e-4            # learning rate shared by both optimizers below
display_step = 100   # the training loop logs/plots every `display_step` steps
jets_no = [0, 1, 2]  # forwarded to show_multiple_images_and_jets

# Training data: every item is a (geant, delphes, per-example delphes max) triple,
# reshuffled at the start of each epoch.
dataset = list(zip(geant, delphes, train_max_each_delphes))
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Generator: two input channels (the training loop concatenates a noise channel
# to the condition image). Moved to `device` first, then re-initialised with
# weights_init.
generator = UNet(2, 1).to(device).apply(weights_init)

# Discriminator: same placement and initialisation as the generator.
discriminator = Discriminator(1).to(device).apply(weights_init)

# One Adam optimizer per network
gen_opt = optim.Adam(generator.parameters(), lr=lr)
disc_opt = optim.Adam(discriminator.parameters(), lr=lr)

### Training

In [None]:
start_t = time()
cur_step = 0
# Running means over the current window of `display_step` steps; each step adds
# loss / display_step, and the accumulators are reset after every report.
mean_discriminator_loss = 0
mean_generator_loss = 0
mean_recon_loss = 0
mean_adv_loss = 0
loss_dict = {"steps": [], "disc_losses": [], "adv_losses": [],
            "recon_losses": [], "gen_losses": []
            }


def _show_denormalized_batch(condition, real, fake, batch_max_delphes, step):
    """Plot one batch after multiplying back the per-example normalisation factor.

    Single entry point for both the periodic (per-step) and the epoch-boundary
    visualisation so that they call show_multiple_images_and_jets with the same
    argument layout.
    """
    show_multiple_images_and_jets(
        condition * batch_max_delphes,
        real * batch_max_delphes,
        fake * batch_max_delphes,
        step, jets_no, 8, 1e-5, True, True)


for epoch in range(n_epochs):

    # Every 10th epoch (but not before any training has happened): checkpoint
    # both networks and show the last batch of the previous epoch.
    if epoch % 10 == 0 and epoch > 0:
        print(f"Epoch: {epoch}, Steps: {cur_step}")
        save_gen_params(generator, epoch)
        save_disc_params(discriminator, epoch)
        # The original call here passed `batch_max_delphes` in the position
        # where the per-step call passes `cur_step`, and did not rescale the
        # images; it now uses the same layout as the per-step call.
        _show_denormalized_batch(condition, real, fake, batch_max_delphes, cur_step)

    for batch in tqdm(dataloader):
        real, condition, batch_max_delphes = batch
        real = real.to(device)
        condition = condition.to(device)
        batch_max_delphes = batch_max_delphes.to(device)

        # Append a uniform-noise channel to the condition -> 2-channel generator input
        noise = torch.rand_like(condition)
        condition_noise = torch.cat([condition, noise], dim=1)
        fake = generator(condition_noise)

        # Discriminator
        # NOTE(review): `fake` is reused for the generator loss below, so this
        # backward pass must not consume the generator's graph — presumably
        # get_disc_loss detaches `fake`; confirm in loss_functions.
        disc_opt.zero_grad()
        cur_disc_loss = get_disc_loss(discriminator, real, fake)
        cur_disc_loss.backward()
        disc_opt.step()

        # Generator
        gen_opt.zero_grad()
        cur_gen_loss, recon_loss, adv_loss = get_gen_loss(discriminator, real, fake)
        cur_gen_loss.backward()
        gen_opt.step()

        # Update losses
        cur_step += 1
        mean_discriminator_loss += cur_disc_loss.item() / display_step
        mean_generator_loss += cur_gen_loss.item() / display_step
        mean_recon_loss += recon_loss.item() / display_step
        mean_adv_loss += adv_loss.item() / display_step

        if cur_step % display_step == 0:
            print("Epoch: {} Steps: {} Disc Loss: {:.3f} Gen Loss: {:.3f} Recon Loss: {:.3f} Adv Loss: {:.3f}".
                  format(epoch, cur_step, mean_discriminator_loss, mean_generator_loss, mean_recon_loss, mean_adv_loss))
            _show_denormalized_batch(condition, real, fake, batch_max_delphes, cur_step)
            loss_dict = show_loss_curve(loss_dict, cur_step, mean_discriminator_loss, mean_adv_loss)

            # Start a fresh averaging window
            mean_discriminator_loss = 0
            mean_generator_loss = 0
            mean_recon_loss = 0
            mean_adv_loss = 0

# Save final model parameters
# NOTE(review): the periodic checkpoints above are tagged with the epoch number,
# these final ones with the total step count.
save_gen_params(generator, cur_step)
save_disc_params(discriminator, cur_step)
print(f"Cell completed in {(time() - start_t)/3600:.3f} hours")
