## 加载库和配置

In [1]:
import torch
import pickle
from generative_trainer import GenerativeTrainer
from generative_models import QuantumGenerative
from utils import *
from torch.utils.data import TensorDataset, DataLoader

from types import SimpleNamespace



In [2]:
# Base hyper-parameters shared by every experiment. The next cell overrides
# selected fields per experiment. Keyword order is significant: it fixes the
# attribute order of the namespace, which in turn fixes `config.exp_name`.
config = SimpleNamespace(
    # Model / circuit layout.
    model="QuantumGenerative",
    i_qubits=2,
    o_qubits=3,
    a_qubits=1,
    # Circuit depth and data re-uploading.
    n_layers=30,
    reuploading=1,
    two_qubit_gate_type="CNOT",
    # Dataset selection and size.
    dataset="TFIM",
    n_data=2000,
    # Optimisation.
    loss_fn="batch_dp_fidelity_loss",
    batch_size=8,
    epochs=10,
    validate_freq=20,
    learning_rate=1e-3,
)

In [3]:
# Per-experiment overrides applied on top of the base config cell.
#
# Select exactly one experiment via EXPERIMENT. Previously this cell stacked
# the SPTS and TFIM assignments one after another, so the SPTS values were
# silently overwritten. Commenting out only one section could mix the two.
#
# SPTS: 4 categories. It sets no layer overrides, so the base config's
#       n_layers / reuploading apply. Settings tried previously
#       (n_layers, reuploading): (10, 3), (6, 5), (10, 5), or n_layers=30 alone.
# TFIM: 2 categories, active setting (n_layers, reuploading) = (10, 3).
#       Alternatives tried: (6, 5), (10, 5), or n_layers=30 alone.
EXPERIMENT = "TFIM"
EXPERIMENT_PRESETS = {
    "SPTS": {"dataset": "SPTS", "categories": 4},
    "TFIM": {"dataset": "TFIM", "categories": 2, "n_layers": 10, "reuploading": 3},
}

# Debug-only overrides that shrink the run for a quick check.
# TODO: confirm DEBUG is False before launching a real experiment.
DEBUG = False
DEBUG_OVERRIDES = {
    "i_qubits": 3,
    "n_layers": 2,
    "n_data": 2000,
    "epochs": 1,
    "a_qubits": 1,
    "reuploading": 2,
    "batch_size": 2,
}

# An unknown EXPERIMENT name raises KeyError here rather than running with
# stale settings.
overrides = dict(EXPERIMENT_PRESETS[EXPERIMENT])
if DEBUG:
    overrides.update(DEBUG_OVERRIDES)
# setattr keeps existing attributes in place and appends new ones (e.g.
# `categories`) at the end, so the attribute order matches the old cell.
for attr, value in overrides.items():
    setattr(config, attr, value)

# The experiment name encodes every config field in insertion order. It must
# be computed after all overrides have been applied.
config.exp_name = '_'.join([f"{key}_{value}" for key, value in vars(config).items()])

## 加载数据集

In [4]:


# Build the phase-transition dataset for the current config. Both helpers come
# from the `from utils import *` star import in the first cell.
# NOTE(review): presumably the dataset is chosen by config.dataset and sized by
# config.n_data. Also presumably X holds the input features, `labels` the phase
# category of each sample (config.categories classes), and `ground_states` the
# target ground state of each sample. Confirm all of this against
# utils.get_phaseTransition / utils.process_phase_transition_data.
phaseTransition = get_phaseTransition(config)
X, labels, ground_states = process_phase_transition_data(phaseTransition)

## 训练模型

In [5]:
# Note: the QuantumGenerative model must be modified to accept two features as
# input and output a quantum state corresponding to the ground state.
# NOTE(review): confirm whether this modification is still outstanding.

# Create and train the GenerativeTrainer (the original comment's "QuantumTrainer"
# is stale). `labels` from the data cell is not passed here. Only the features X
# and ground_states are loaded, presumably as (input, target) pairs. Confirm in
# GenerativeTrainer.load_data.
# NOTE(review): train() prints "Validation Loss" lines periodically and an
# "Epoch N, Loss" line per epoch (see output below). The validation cadence
# presumably follows config.validate_freq; confirm.
trainer = GenerativeTrainer(config)
trainer.load_data(X, ground_states)
trainer.train()

Validation Loss: 0.8665572507903299
Validation Loss: 0.7623634871318007
Validation Loss: 0.6764828339631588
Validation Loss: 0.6018153267419579
Validation Loss: 0.5356486268254431
Validation Loss: 0.4890175356644138
Validation Loss: 0.45279807574361114
Validation Loss: 0.4293949741537868
Validation Loss: 0.40976865863912465
Validation Loss: 0.38694335297293053
Validation Loss: 0.3754391938181933
Validation Loss: 0.3653356139242349
Epoch 1, Loss: 0.5149228552183923
Validation Loss: 0.3506977356875305
Validation Loss: 0.3398434330217503
Validation Loss: 0.33199783703468333
Validation Loss: 0.3247245440934001
Validation Loss: 0.3222039770298651
Validation Loss: 0.3148215688667049
Validation Loss: 0.3091276851781231
Validation Loss: 0.3029822979330418
Validation Loss: 0.2985988890092429
Validation Loss: 0.2989219925768204
Validation Loss: 0.3012453907815721
Epoch 2, Loss: 0.2995077636074934
Validation Loss: 0.28918452173382897
Validation Loss: 0.28776842297757954
Validation Loss: 0.2866637