In [1]:
import json
import numpy as np
import torch
import torch.nn.functional as F

In [3]:
def load_json(file_path, data_key='data', field_key='hvf', eyes=('R', 'L')):
    """Load every visual-field matrix in a JSON file as one flattened row per session.

    Layout assumed from the keys read below (confirm against the data source)::

        {data_key: {patient_id: {eye: [{field_key: <nested list>}, ...], ...}, ...}}

    Args:
        file_path: Path to a UTF-8 encoded JSON file.
        data_key: Top-level key holding the per-patient mapping. Default ``'data'``.
        field_key: Per-session key holding the matrix. Default ``'hvf'``.
        eyes: Eye labels to read, in output order. Eyes a patient has no entry
            for are skipped. Default ``('R', 'L')``.

    Returns:
        np.ndarray of shape ``(n_sessions, n_values)``. Rows are ordered by patient,
        then by eye (in ``eyes`` order), then by session. If no sessions are found,
        the result is an empty array of shape ``(0,)``.

    Raises:
        KeyError: if ``data_key`` or ``field_key`` is missing.
        ValueError: if the flattened matrices do not all have the same length.
    """
    # JSON text is UTF-8 by spec; don't depend on the platform's default encoding.
    with open(file_path, 'r', encoding='utf-8') as f:
        data = json.load(f)[data_key]

    # Extract visual-field matrices from all patients and the requested eyes.
    X = [
        np.array(session[field_key]).flatten()  # flatten each matrix to a 1-D vector
        for patient in data.values()
        for eye in eyes
        if eye in patient  # not every patient has sessions for every eye
        for session in patient[eye]
    ]

    # Ragged rows cannot form a 2-D matrix. Newer numpy raises an opaque error and
    # older numpy silently builds an object array, so fail early with a clear message.
    lengths = sorted({row.size for row in X})
    if len(lengths) > 1:
        raise ValueError(
            f"Visual-field matrices have inconsistent sizes {lengths}; "
            "cannot stack them into a 2-D array."
        )

    # Stack the row vectors into a single (n_sessions, n_values) array.
    X = np.array(X)
    print(f"Loaded dataset with shape: {X.shape}")  # Log the shape of the dataset
    return X

In [5]:
def gumbel_softmax(logits, temperature=1.0, hard=False, dim=-1, generator=None):
    """Draw a Gumbel-Softmax (Concrete) relaxed sample from unnormalized logits.

    Args:
        logits: Floating-point tensor of unnormalized log-probabilities; the
            categories lie along ``dim``.
        temperature: Positive float. Lower values push samples closer to one-hot.
        hard: If True, the forward pass returns exact one-hot vectors, while
            gradients flow through the soft sample (straight-through estimator).
            Default False, which returns the soft sample.
        dim: Dimension holding the categories. Default -1.
        generator: Optional ``torch.Generator`` for reproducible noise. If None,
            the global torch RNG is used.

    Returns:
        Tensor with the same shape as ``logits``, whose entries along ``dim`` sum to 1.

    Raises:
        ValueError: if ``temperature`` is not positive.
    """
    if temperature <= 0:
        raise ValueError(f"temperature must be positive, got {temperature}")

    # Uniform samples lie in [0, 1). An exact 0 would make -log(-log(u)) = -inf,
    # so clamp to the smallest positive normal value to keep the noise finite.
    uniform = torch.rand(
        logits.shape, generator=generator, dtype=logits.dtype, device=logits.device
    )
    uniform = torch.clamp(uniform, min=torch.finfo(logits.dtype).tiny)
    gumbel_noise = -torch.log(-torch.log(uniform))

    # Perturb the logits, sharpen them by the temperature, and normalize.
    y_soft = F.softmax((logits + gumbel_noise) / temperature, dim=dim)
    if not hard:
        return y_soft

    # Straight-through: the forward value is one-hot, the gradient is that of y_soft.
    index = y_soft.argmax(dim=dim, keepdim=True)
    y_hard = torch.zeros_like(y_soft).scatter_(dim, index, 1.0)
    return y_hard - y_soft.detach() + y_soft

In [7]:
if __name__ == "__main__":
    # Configuration: dataset path, RNG seed (the Gumbel noise is random), and the
    # softmax temperature (lower -> samples closer to one-hot).
    DATA_PATH = 'alldata.json'
    SEED = 42
    TEMPERATURE = 0.5

    # Seed the global torch RNG so the sampled noise, and the printed output, is reproducible.
    torch.manual_seed(SEED)

    # Load the flattened visual-field matrices: one row per (patient, eye, session).
    data = load_json(DATA_PATH)

    # Treat each row as a vector of unnormalized logits.
    # NOTE(review): the raw visual-field values are used directly as logits over the
    # flattened test locations; confirm that this is the intended modelling choice.
    logits = torch.as_tensor(data, dtype=torch.float32)

    # Gumbel-Softmax yields *relaxed* samples: each row is a probability vector that
    # approaches one-hot as the temperature decreases. It is not a hard categorical draw.
    discrete_data = gumbel_softmax(logits, temperature=TEMPERATURE)

    # Print the sampled probability vectors (torch summarizes large tensors).
    print("Discrete Data (Gumbel-Softmax Output):", discrete_data)

Loaded dataset with shape: (28943, 72)
Discrete Data (Gumbel-Softmax Output): tensor([[1.2506e-01, 9.0049e-02, 1.6940e-02,  ..., 0.0000e+00, 1.7471e-02,
         7.5672e-03],
        [1.3855e-03, 1.2340e-03, 2.7843e-05,  ..., 0.0000e+00, 1.3842e-04,
         6.2672e-05],
        [3.5049e-04, 1.4990e-05, 1.9932e-04,  ..., 0.0000e+00, 2.4023e-05,
         3.9524e-06],
        ...,
        [3.7879e-03, 2.2456e-02, 3.6819e-02,  ..., 0.0000e+00, 2.5020e-03,
         1.3046e-02],
        [1.1567e-03, 1.3408e-04, 9.4672e-04,  ..., 0.0000e+00, 1.0159e-02,
         1.0214e-02],
        [9.3732e-02, 7.0110e-03, 7.3696e-02,  ..., 0.0000e+00, 2.4426e-04,
         1.4998e-02]])
