In [2]:
import os
import time
import gc
import random
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from scipy.stats import wasserstein_distance
from torch import nn, optim
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset

# Reproducibility: seed every RNG this notebook draws from (Python, NumPy,
# PyTorch CPU and all CUDA devices). Without the torch seeds, weight
# initialisation, dropout masks and torch.randn noise differ on every run.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)  # silently ignored when CUDA is unavailable
# Trade some cuDNN autotuning speed for repeatable convolution results.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# Device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

Using device: cuda


In [None]:
# Detector responses (images) and the particle features used as conditions.
# Only the first N_SAMPLES events are used to keep runs short.
DATA_DIR = './Project/data'
N_SAMPLES = 20000

# np.load on an .npz returns an NpzFile that keeps the archive open; using it
# as a context manager releases the file handle once the array has been read.
with np.load(os.path.join(DATA_DIR, 'data_nonrandom_responses.npz')) as archive:
    data = archive['arr_0'][:N_SAMPLES]
with np.load(os.path.join(DATA_DIR, 'data_nonrandom_particles.npz')) as archive:
    data_cond = archive['arr_0'][:N_SAMPLES]
data_cond = pd.DataFrame(data_cond, columns=['Energy','Vx','Vy','Vz','Px','Py','Pz','mass','charge'])

print("Loaded data, shape:", data.shape, data_cond.shape)

In [None]:
# Compress the heavy-tailed detector responses. log1p(x) computes log(x + 1)
# accurately for small x; invert with expm1.
data = np.log1p(data).astype(np.float32)

# Split BEFORE scaling so the conditioning scaler is fit on training rows only
# and no test-set statistics leak into the training features.
# NOTE(review): shuffle=False keeps file order; if the file is sorted, the
# test split is not i.i.d. with the training split — confirm intended.
x_train, x_test, y_train, y_test = train_test_split(data, data_cond, test_size=0.2, shuffle=False)
cond_scaler = StandardScaler().fit(y_train)
y_train = cond_scaler.transform(y_train).astype(np.float32)
y_test = cond_scaler.transform(y_test).astype(np.float32)
# Keep `data_cond` as the full scaled float32 array, as before, for later cells.
data_cond = cond_scaler.transform(data_cond).astype(np.float32)

print("Preprocessed data")

x_train = torch.tensor(x_train).to(device)
y_train = torch.tensor(y_train).to(device)
x_test = torch.tensor(x_test).to(device)
y_test = torch.tensor(y_test).to(device)

batch_size = 128
train_dataset = TensorDataset(x_train, y_train)
# Seeded generator makes the shuffling order reproducible.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, generator=torch.Generator().manual_seed(SEED))


def get_channel_masks(shape):
    """Build the five float64 pixel masks used to split an image into channels.

    A checkerboard marks pixels where (row + col) is odd. Channels 1-4 are that
    checkerboard restricted to one quadrant each, using the split point
    (rows // 2, cols // 2):
        ch1 = bottom-left, ch2 = bottom-right, ch3 = top-left, ch4 = top-right.
    Channel 5 is the complementary checkerboard (pixels where row + col is even).
    The five masks sum to an all-ones array.

    Args:
        shape: (rows, cols) of the image.

    Returns:
        Tuple of five float64 arrays of the given shape.
    """
    rows, cols = shape
    checker = (np.add.outer(np.arange(rows), np.arange(cols)) % 2).astype(np.float64)
    half_r, half_c = rows // 2, cols // 2
    quadrant_slices = (
        (slice(half_r, None), slice(None, half_c)),  # ch1: bottom-left
        (slice(half_r, None), slice(half_c, None)),  # ch2: bottom-right
        (slice(None, half_r), slice(None, half_c)),  # ch3: top-left
        (slice(None, half_r), slice(half_c, None)),  # ch4: top-right
    )
    quadrant_masks = []
    for row_sl, col_sl in quadrant_slices:
        masked = np.zeros_like(checker)
        masked[row_sl, col_sl] = checker[row_sl, col_sl]
        quadrant_masks.append(masked)
    return (*quadrant_masks, 1.0 - checker)

def sum_channels_parallel(data):
    """Reduce each image to five channel sums using get_channel_masks.

    Args:
        data: tensor of shape (N, H, W).

    Returns:
        Tensor of shape (N, 5) on ``data.device``. Column k is the sum of the
        image pixels selected by mask k+1. The computation runs in float64 (the
        masks' dtype), matching the previous promotion behaviour.
    """
    masks = np.stack(get_channel_masks(data.shape[1:]))  # (5, H, W), float64
    # Place the masks on the input's own device rather than the global `device`,
    # so CPU tensors also work when a GPU is present.
    masks = torch.as_tensor(masks, device=data.device)
    # One contraction over all pixels for the five masks: (N,H,W) x (5,H,W) -> (N,5).
    return torch.einsum('nhw,khw->nk', data.to(masks.dtype), masks)

def calculate_ws_ch(generator, y_test, x_test, n_calc=5, batch_size=256, noise_dim=10, seed=None):
    """Average per-channel Wasserstein distance between real and generated images.

    Real samples and generated samples are mapped back from log(x + 1) space
    with expm1, reshaped to 44x44 and reduced to five channel sums by
    ``sum_channels_parallel``. For each of ``n_calc`` independent noise draws,
    the 1-D Wasserstein distance between real and generated sums is computed
    per channel. The results are averaged over draws, printed, and returned.

    Args:
        generator: conditional model called as ``generator(z, y)``. It is run
            in eval mode; its previous train/eval mode is restored afterwards.
        y_test: conditioning tensor with one row per test sample.
        x_test: real samples in log(x + 1) space, reshapeable to (-1, 44, 44).
        n_calc: number of independent noise draws to average over (>= 1).
        batch_size: generation batch size.
        noise_dim: width of the latent vector ``z`` (previously fixed at 10).
        seed: seed for the noise RNG; defaults to the global ``SEED``.

    Returns:
        np.ndarray of shape (5,) with the averaged per-channel distances.

    Raises:
        ValueError: if ``n_calc < 1`` or ``x_test`` is empty.
    """
    if n_calc < 1:
        raise ValueError(f"n_calc must be >= 1, got {n_calc}")
    n_samples = x_test.size(0)
    if n_samples == 0:
        raise ValueError("x_test is empty")
    if seed is None:
        seed = SEED

    # Seed ONE generator once: draws now differ across batches and repetitions
    # (so the n_calc average is meaningful) while the evaluation stays
    # reproducible. Re-seeding per batch made every repetition identical.
    noise_rng = torch.Generator().manual_seed(seed)
    was_training = generator.training
    # Eval mode: no dropout, and BatchNorm running stats are not mutated.
    generator.eval()
    try:
        with torch.no_grad():
            org = torch.expm1(x_test).view(-1, 44, 44)
            ch_org = sum_channels_parallel(org).cpu().numpy()

            ws = np.zeros(5)
            for _ in range(n_calc):
                ch_gen_list = []
                for start in range(0, n_samples, batch_size):
                    end = min(start + batch_size, n_samples)
                    z = torch.randn(end - start, noise_dim, generator=noise_rng).to(y_test.device)
                    fake = torch.expm1(generator(z, y_test[start:end])).view(-1, 44, 44)
                    ch_gen_list.append(sum_channels_parallel(fake).cpu().numpy())

                ch_gen = np.concatenate(ch_gen_list, axis=0)
                for ch in range(5):
                    ws[ch] += wasserstein_distance(ch_org[:, ch], ch_gen[:, ch])

            ws /= n_calc
    finally:
        generator.train(was_training)
        torch.cuda.empty_cache()
        gc.collect()

    print("ws mean", f"{ws.mean():.2f}", end=" ")
    for n, score in enumerate(ws):
        print(f"ch{n+1} {score:.2f}", end=" ")
    print()
    return ws

def generate_and_save_images(model, epoch, test_input, cond_input):
    """Plot 7 real test images above 7 generated ones and save the figure.

    The top row shows the global ``x_test[20:27]``; the bottom row shows the
    first 7 outputs of ``model(test_input, cond_input)``. Each panel is 44x44
    with its own colorbar. The figure is saved as
    ``image_at_epoch_{epoch:04d}.png`` and then shown.
    NOTE(review): presumably callers pass conditions for the same rows
    (``y_test[20:27]``) so the two rows are comparable — TODO confirm.

    The model is evaluated in eval mode, and its previous train/eval mode is
    restored afterwards. Previously it was left in eval mode, which silently
    disabled dropout and BatchNorm updates for any training that followed.
    """
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            predictions = model(test_input, cond_input).cpu().numpy()
    finally:
        model.train(was_training)

    image_rows = (x_test[20:27].cpu().numpy(), predictions)
    fig, axs = plt.subplots(2, 7, figsize=(15, 4))
    for row, images in enumerate(image_rows):
        for col in range(7):
            ax = axs[row, col]
            im = ax.imshow(images[col].reshape(44, 44), cmap='gnuplot')
            ax.axis('off')
            fig.colorbar(im, ax=ax)
    fig.savefig(f'image_at_epoch_{epoch:04d}.png')
    plt.show()
    plt.close(fig)  # release the figure explicitly; this is called every epoch
    torch.cuda.empty_cache()
    gc.collect()

In [None]:
class Generator(nn.Module):
    """Conditional generator mapping (noise, condition) to a 1x44x44 image.

    The concatenated inputs pass through two dense layers that project them to
    a 128x13x13 feature map. That map is expanded spatially:
        13 -up2-> 26 -conv3-> 24 -up2-> 48 -conv3-> 46 -conv2-> 45 -conv2-> 44
    The final ReLU keeps every output pixel >= 0.

    Args:
        noise_dim: width of the noise vector.
        cond_dim: width of the conditioning vector.
    """

    # (channels, height, width) of the dense projection viewed as a feature map.
    _SEED_SHAPE = (128, 13, 13)

    def __init__(self, noise_dim, cond_dim):
        super().__init__()
        channels, height, width = self._SEED_SHAPE
        seed_features = channels * height * width

        self.fc1 = nn.Sequential(
            nn.Linear(noise_dim + cond_dim, 256),
            nn.BatchNorm1d(256),
            nn.Dropout(0.2),
            nn.LeakyReLU(0.1, inplace=True),
        )
        self.fc2 = nn.Sequential(
            nn.Linear(256, seed_features),
            nn.BatchNorm1d(seed_features),
            nn.Dropout(0.2),
            nn.LeakyReLU(0.1, inplace=True),
        )
        self.upsample = nn.Upsample(scale_factor=(2, 2))
        self.conv_layers = nn.Sequential(
            nn.Conv2d(channels, 256, kernel_size=3),
            nn.BatchNorm2d(256),
            nn.Dropout(0.2),
            nn.LeakyReLU(0.1, inplace=True),
            nn.Upsample(scale_factor=(2, 2)),
            nn.Conv2d(256, 128, kernel_size=3),
            nn.BatchNorm2d(128),
            nn.Dropout(0.2),
            nn.LeakyReLU(0.1, inplace=True),
            nn.Conv2d(128, 64, kernel_size=2),
            nn.BatchNorm2d(64),
            nn.Dropout(0.2),
            nn.LeakyReLU(0.1, inplace=True),
            nn.Conv2d(64, 1, kernel_size=2),
            nn.ReLU(inplace=True),
        )

    def forward(self, noise, cond):
        """Return a (N, 1, 44, 44) image batch for noise (N, noise_dim) and cond (N, cond_dim)."""
        dense_out = self.fc2(self.fc1(torch.cat((noise, cond), dim=1)))
        feature_map = dense_out.view(-1, *self._SEED_SHAPE)
        return self.conv_layers(self.upsample(feature_map))

In [None]:
class Discriminator(nn.Module):
    """Conditional discriminator for 1x44x44 images.

    Two conv/pool stages reduce a (N, 1, 44, 44) image to (N, 16, 9, 9):
        44 -conv3-> 42 -pool2-> 21 -conv3-> 19 -pool2-> 9
    The flattened features are concatenated with the condition vector and
    passed through two dense layers and a final linear + sigmoid.

    forward returns ``(out, latent)``: ``out`` is the (N, 1) sigmoid output and
    ``latent`` is the (N, 64) penultimate feature vector.

    Args:
        cond_dim: width of the conditioning vector.
    """

    # Flattened conv-stack output for a 44x44 input: 16 channels x 9 x 9 = 1296.
    # Previously written as 9 * 12 * 12, which equals 1296 only by coincidence.
    _CONV_OUT_FEATURES = 16 * 9 * 9

    def __init__(self, cond_dim):
        super().__init__()
        self.conv_layers = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3),
            nn.BatchNorm2d(32),
            nn.LeakyReLU(0.1, inplace=True),
            nn.Dropout(0.2),
            nn.MaxPool2d(kernel_size=2),
            nn.Conv2d(32, 16, kernel_size=3),
            nn.BatchNorm2d(16),
            nn.LeakyReLU(0.1, inplace=True),
            nn.Dropout(0.2),
            nn.MaxPool2d(kernel_size=2)
        )
        self.fc1 = nn.Sequential(
            nn.Linear(self._CONV_OUT_FEATURES + cond_dim, 128),
            nn.BatchNorm1d(128),
            nn.LeakyReLU(0.1, inplace=True),
            nn.Dropout(0.2)
        )
        self.fc2 = nn.Sequential(
            nn.Linear(128, 64),
            nn.BatchNorm1d(64),
            nn.LeakyReLU(0.1, inplace=True),
            nn.Dropout(0.2)
        )
        self.fc3 = nn.Linear(64, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, img, cond):
        """Score images ``img`` (N, 1, 44, 44) given conditions ``cond`` (N, cond_dim)."""
        x = self.conv_layers(img)
        x = torch.flatten(x, 1)  # (N, 16 * 9 * 9)
        x = torch.cat((x, cond), dim=1)
        x = self.fc1(x)
        latent = self.fc2(x)
        out = self.sigmoid(self.fc3(latent))
        return out, latent