### Imports and Setup

In [None]:
import torch
import classiq

import torch.nn as nn
import torch.optim as optim
from torchinfo import summary

import sys
sys.path.append("../..") # Add the parent directory to the sys.path list

from models.leqm3 import linear_entanglement_r3_quantum_model
from models.qnn import execute_fn, post_process_fn, QNN

from scripts.helper import create_writer, write_train_results
from scripts.data_setup import create_mnist_dataloaders
from scripts.data_transforms import input_transform, target_transform
from scripts.train import train
from scripts.test import test
from scripts.save_model import save_model

In [None]:
## Authenticate Classiq
# classiq.authenticate()

In [None]:
## Device selection. The device-agnostic CUDA check is kept below for reference,
## but this notebook deliberately pins execution to the CPU.
# device = "cuda" if torch.cuda.is_available() else "cpu"
device = "cpu"
device  # echo the selected device as the cell output

In [None]:
# ## Clear Output Files
# post_process_output_file = open("post_process_output.txt", "w")
# print("-----------------------------------------------------------------------------------------------------------------", file=post_process_output_file)
# print("--------------------------------------------POST PROCESS OUTPUT--------------------------------------------------", file=post_process_output_file)
# print("-----------------------------------------------------------------------------------------------------------------", file=post_process_output_file)
# post_process_output_file.close()

# test_loop_output_file = open("test_loop_output.txt", "w")
# print("-----------------------------------------------------------------------------------------------------------------", file=test_loop_output_file)
# print("-----------------------------------------------TEST LOOP OUTPUT--------------------------------------------------", file=test_loop_output_file)
# print("-----------------------------------------------------------------------------------------------------------------", file=test_loop_output_file)
# test_loop_output_file.close()

In [None]:
## HYPER PARAMETERS
# Step size handed to the SGD optimizer built over qnn.parameters().
_LEARNING_RATE: float = 1.0
# Samples per batch requested from create_mnist_dataloaders.
BATCH_SIZE: int = 32
# Epoch count passed to train(); it is also embedded in the writer/model names.
EPOCHS: int = 2

### Quantum Model

In [None]:
## Create a Linear Entanglement Quantum Model for MNIST Data Classification with three linear entanglement layers of RXX, RYY, and RZZ.
# The model is only constructed here; a later cell passes it to classiq.synthesize.
quantum_model = linear_entanglement_r3_quantum_model()

In [None]:
# Synthesize the model with Classiq; the resulting program is what the QNN cell wraps.
# NOTE(review): synthesis presumably requires an authenticated Classiq session
# (see the commented-out classiq.authenticate() cell) — confirm before running fresh.
quantum_program = classiq.synthesize(quantum_model)

In [None]:
# View Quantum Program on Classiq Platform
# classiq.show(quantum_program)

### Quantum Neural Network

In [None]:
# Wrap the synthesized quantum program in the project's QNN model. The result exposes
# .parameters() (used by the optimizer cell) and is passed as model= to train/test.
# execute_fn / post_process_fn come from models.qnn; presumably they run the circuit
# and turn its raw results into output tensors — confirm against models.qnn.
qnn = QNN(
    quantum_program=quantum_program,
    execute=execute_fn,
    post_process=post_process_fn,
)

In [None]:
# Optional torchinfo summary (disabled). input_size uses BATCH_SIZE instead of a hard-coded 32
# so it stays in sync with the hyperparameter cell if re-enabled.
# NOTE(review): 16 is assumed to be the per-sample feature count produced by input_transform — confirm.
# summary(model=qnn, input_size=(BATCH_SIZE, 16), verbose=0, col_names=["input_size", "output_size", "num_params", "trainable"], col_width=20, row_settings=["var_names"])

In [None]:
# Loss: torch's L1Loss, i.e. mean absolute error under its default reduction.
# NOTE(review): an L1 loss for a classification task is unusual; presumably the
# outputs of post_process_fn and target_transform are shaped to make it work — confirm.
loss_fn = nn.L1Loss()

# Optimizer: vanilla SGD over every QNN parameter at the configured learning rate.
optimizer = optim.SGD(params=qnn.parameters(), lr=_LEARNING_RATE)

### Preparing Data

In [None]:
# Build the MNIST train/test dataloaders plus the list of class names.
train_dataloader, test_dataloader, class_names = create_mnist_dataloaders(
    root="../../data",                  # data directory, relative to this notebook
    transform=input_transform,          # transform applied to each input sample
    target_transform=target_transform,  # transform applied to each label
    batch_size=BATCH_SIZE,
    create_subset=True,                 # restrict to a reduced subset ...
    subset_size=256,                    # ... of this size (matches "256" in run/model names)
)

In [None]:
# Sanity-check what create_mnist_dataloaders returned.
print(f"Dataloaders: {train_dataloader, test_dataloader}") 
# One "Length of ..." line per split; a generator keeps the loop names out of the notebook namespace.
print(
    *(
        f"Length of {split} dataloader: {len(dl)} batches of {BATCH_SIZE}"
        for split, dl in (("train", train_dataloader), ("test", test_dataloader))
    ),
    sep="\n",
)
print(f"Our Dataset have following classes: {class_names}")

In [None]:
# Pull a single batch from the training loader to sanity-check tensor shapes.
data, label = next(iter(train_dataloader))

# The bracketed axis names are the author's expected layout; the code does not check them.
print(f"Image shape: {data.shape} -> [batch_size, pixel_angle]")
print(f"Label shape: {label.shape} -> [batch_size, label_value]")

#### Run Experiment

##### 01. Train

In [None]:
# Create the experiment-tracking writer that is handed to train() below.
# NOTE(review): experiment_name here is "data_256", while the write_train_results
# call further down uses "exp_2_data_256" — confirm the mismatch is intentional.
writer = create_writer(
    experiment_name="data_256",
    model_name="linear_entanglement_r3",
    extra=f"{EPOCHS}_epochs",
)

In [None]:
# Train the QNN for EPOCHS epochs; the returned results are printed and
# recorded by the following cells.
train_results = train(
    model=qnn,
    data_loader=train_dataloader,
    loss_fn=loss_fn,
    optimizer=optimizer,
    writer=writer,
    epochs=EPOCHS,
    device=device,
)

In [None]:
# Check out the model results returned by train()
print(train_results)

In [None]:
# Record the training results via write_train_results.
# NOTE(review): "exp_2_data_256" differs from the "data_256" experiment name given to
# create_writer earlier in the notebook — confirm which naming is intended.
write_train_results(
    experiment_name="exp_2_data_256",
    model_name="linear_entanglement_r3",
    epochs=EPOCHS,
    results=train_results,
)

In [None]:
# %load_ext tensorboard
# %tensorboard --logdir runs

##### 02. Save the trained model

In [None]:
# Save the trained QNN under outputs/saved_models; the file name embeds the epoch count.
save_model(
    model=qnn,
    target_dir="outputs/saved_models",
    model_name=f"exp_2_leqmr3_subset256_epoch{EPOCHS}.pt",
)

##### 03. Test

In [None]:
# Evaluate the trained QNN; test() is called with only the model, loader and device.
test_results = test(
    model=qnn,
    data_loader=test_dataloader,
    device=device,
)

In [None]:
# Show the evaluation results returned by test().
print(test_results)