In [9]:
import qutip as qt
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

# Seed torch's global RNG (all devices) so model initialization and
# DataLoader shuffling are reproducible across Restart & Run All.
SEED = 42
torch.manual_seed(SEED)

# Pick the best available accelerator: CUDA first, then Apple MPS, then CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

print("CUDA AVAILABLE:", torch.cuda.is_available())
if device == "cuda":
    print("DEVICE:", torch.cuda.get_device_name(torch.cuda.current_device()))
print(f"Using {device} device")


CUDA AVALIABLE: True
DEVICE: NVIDIA RTX A4500
Using cuda device


In [7]:
# Custom Imports
from data_gen import read_data, dset_size

## Load Data

In [None]:
# Complex dtype shared by the tensors and the model in later cells.
complex_type = torch.complex64
file = "data.hdf5"

# Read the "train" split with dim=2, entries i_start=0 .. i_end=1000;
# read_data yields a (states, H) pair.
states, H = read_data(
    file,
    dim=2,
    label="train",
    i_start=0,
    i_end=1000,
)

In [None]:
# Build a complex tensor from H[0, 0] directly on the selected device,
# then display it together with its shape.
d = torch.tensor(H[0, 0], dtype=complex_type, device=device)
(d, d.shape)

## Try a Model on Complex-Valued Data

In [None]:
class Net(nn.Module):
    """Small fully-connected network with complex-valued parameters.

    Flattens its input, then applies two complex ``nn.Linear`` layers
    (4 -> 2 -> 1), each followed by ``nn.Sigmoid``.

    Args:
        dtype: dtype of the layer parameters. When ``None`` (the default),
            the notebook-level ``complex_type`` is used.
    """

    def __init__(self, dtype=None):
        super().__init__()
        if dtype is None:
            # Fall back to the notebook-wide complex dtype.
            dtype = complex_type
        # start_dim=0 flattens *every* axis, including any batch axis, so the
        # model expects a single unbatched sample with 4 elements in total.
        # NOTE(review): batched input would need start_dim=1 — TODO confirm intent.
        self.flatten = nn.Flatten(start_dim=0, end_dim=-1)
        self.dense1 = nn.Sequential(
            nn.Linear(4, 2, dtype=dtype),
            nn.Sigmoid(),
        )
        self.dense2 = nn.Sequential(
            nn.Linear(2, 1, dtype=dtype),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Map one sample ``x`` (4 elements once flattened) to a 1-element output."""
        x = self.flatten(x)
        x = self.dense1(x)
        return self.dense2(x)

# Instantiate the network; Module.to moves its parameters in place
# (and returns the same module), so the return value is not needed here.
model = Net()
model.to(device)
print(model)

In [None]:
# Smoke-test a forward pass on the single sample `d`.
prediction = model(d)
prediction

## Build PyTorch DataLoaders for the Train and Test Splits

In [4]:
class myDataset(Dataset):
    """Map-style dataset that reads one sample at a time from an HDF5 file.

    Args:
        hfile: path of the HDF5 file passed to ``dset_size`` and ``read_data``.
        dim: dimension argument forwarded to ``dset_size`` and ``read_data``.
        label: name of the split to read (e.g. ``"train"`` or ``"test"``).
    """

    def __init__(self, hfile, dim, label):
        self.hfile = hfile
        self.dim = dim
        # Stored so __getitem__ reads from the same split that was sized here.
        self.label = label
        # NOTE(review): assumes dset_size returns a shape tuple whose first
        # entry is the number of samples — confirm against data_gen.
        self.shape = dset_size(hfile, dim, label)

    def __len__(self):
        return self.shape[0]

    def __getitem__(self, idx):
        # Reads entries idx .. idx + 1. Presumably each returned array keeps a
        # leading length-1 axis — TODO confirm and squeeze if batch shape matters.
        return read_data(self.hfile, self.dim, self.label, idx, idx + 1)

In [10]:
# Data source and batching configuration for the PyTorch pipeline.
hfile = "data.hdf5"
dim = 2
BATCH_SIZE = 32

train_data = myDataset(hfile, dim, "train")
# Reshuffle the training split every epoch.
train_dataloader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True)

test_data = myDataset(hfile, dim, "test")
# Shuffling the evaluation split only adds nondeterminism; iterate in file order.
test_dataloader = DataLoader(test_data, batch_size=BATCH_SIZE, shuffle=False)