In [None]:
import numpy as np
import pandas as pd
from tqdm import tqdm
from discopy.rigid import Spider
from discopro.grammar import tensor
from lambeq import Rewriter, Dataset
from lambeq import QuantumTrainer, SPSAOptimizer
from lambeq import AtomicType, IQPAnsatz, remove_cups
from discopro.anaphora import connect_anaphora_on_top
from lambeq import BobcatParser, NumpyModel, AtomicType

In [None]:
# Shared parser instance; its `sentence2diagram` method is called in
# `anaphoraSent2dig` to turn each sentence into a diagram.
parser = BobcatParser()

In [None]:
# Rewriter configured with the named rewrite rules below; it is applied to
# every diagram produced by `generate_diagram`.
rewriter = Rewriter([
    'auxiliary',
    'connector',
    'coordination',
    'determiner',
    'object_rel_pronoun',
    'subject_rel_pronoun',
    'postadverb',
    'preadverb',
    'prepositional_phrase',
])

In [None]:
# Short aliases for the atomic types used by the ansatz and the merge spider.
N, S, P = (
    AtomicType.NOUN,
    AtomicType.SENTENCE,
    AtomicType.PREPOSITIONAL_PHRASE,
)

# IQP ansatz mapping each atomic type to 1, with a single layer and three
# single-qubit parameters.
ansatz = IQPAnsatz(
    {N: 1, S: 1, P: 1},
    n_layers=1,
    n_single_qubit_params=3,
)

In [None]:
def generate_diagram(diagram, pro, ref):
    """Link a pronoun to its referent in `diagram` and simplify the result.

    Args:
        diagram: Diagram exposing a `boxes` sequence whose items have a `name`.
        pro: Pronoun word; matched case-insensitively against box names.
        ref: Referent word; matched the same way.

    Returns:
        The diagram obtained by applying `connect_anaphora_on_top` to the two
        box indices, then `remove_cups`, the module-level `rewriter`, and
        finally `normal_form()`.

    Raises:
        ValueError: if no box in `diagram` is named `pro` or `ref`. (A bare
            `next(...)` would leak a message-less `StopIteration` instead,
            which is hard to diagnose and unsafe inside generators.)
    """

    def _first_box_index(word):
        # Index of the first box whose name equals `word`, ignoring case.
        wanted = word.casefold()
        for idx, box in enumerate(diagram.boxes):
            if box.name.casefold() == wanted:
                return idx
        raise ValueError(f"no box named {word!r} found in the diagram")

    pro_box_idx = _first_box_index(pro)
    ref_box_idx = _first_box_index(ref)

    final_diagram = connect_anaphora_on_top(diagram, pro_box_idx, ref_box_idx)
    rewritten_diagram = rewriter(remove_cups(final_diagram)).normal_form()

    return rewritten_diagram

In [None]:
def anaphoraSent2dig(sentence1, sentence2, pro, ref):
    """Build one anaphora-linked diagram from a pair of sentences.

    Both sentences are parsed with the module-level `parser`, placed side by
    side with `tensor`, and their two sentence outputs are merged by a
    `Spider(2, 1, S)`. The combined diagram is then handed to
    `generate_diagram` together with the pronoun and referent words.

    Args:
        sentence1: First sentence (text).
        sentence2: Second sentence (text).
        pro: Pronoun word to look up among the diagram's boxes.
        ref: Referent word to look up among the diagram's boxes.

    Returns:
        Whatever `generate_diagram` returns for the merged diagram.
    """
    parsed = [parser.sentence2diagram(text) for text in (sentence1, sentence2)]
    merged = tensor(*parsed) >> Spider(2, 1, S)
    return generate_diagram(merged, pro, ref)

In [None]:
def generate_diag_labels(df):
    """Convert every row of `df` into a diagram, a circuit and a one-hot label.

    Each row must provide the columns 'label', 'referent', 'sent1', 'sent2'
    and 'pronoun'. A label equal to 1 becomes `[1.0, 0.0]`; anything else
    becomes `[0.0, 1.0]`.

    Rows for which diagram construction or the ansatz raises are reported
    with a printed message and skipped.

    Args:
        df: DataFrame with the columns listed above.

    Returns:
        A tuple `(circuits, labels, diagrams)` of equally long lists, where
        position `k` in each list refers to the same successfully processed
        row.
    """
    circuits, labels, diagrams = [], [], []

    for _, row in tqdm(df.iterrows(), total=len(df)):
        label = [1.0, 0.0] if row['label'] == 1 else [0.0, 1.0]
        ref = row['referent']
        sent1, sent2, pro = row[['sent1', 'sent2', 'pronoun']]

        try:
            diagram = anaphoraSent2dig(sent1.strip(), sent2.strip(), pro.strip(), ref.strip())
            discopy_circuit = ansatz(diagram)
        except Exception as e:
            # Print an error message if an exception occurs, then skip the row.
            print("An error occurred:", e)
            continue

        # Append only once BOTH steps succeeded. Appending the diagram before
        # running the ansatz would leave an orphan entry in `diagrams` when
        # the ansatz fails, shifting it out of step with circuits/labels.
        diagrams.append(diagram)
        circuits.append(discopy_circuit)
        labels.append(label)

    return circuits, labels, diagrams

In [None]:
# Read the three splits; the first CSV column is used as the index.
df_train, df_val, df_test = (
    pd.read_csv(csv_path, index_col=0)
    for csv_path in ('train.csv', 'val.csv', 'test.csv')
)

# Each call yields (circuits, labels, diagrams) for one split.
train_circuits, train_labels, train_diagrams = generate_diag_labels(df_train)
val_circuits, val_labels, val_diagrams = generate_diag_labels(df_val)
test_circuits, test_labels, test_diagrams = generate_diag_labels(df_test)

In [None]:
# The model is built from the circuits of all three splits together.
all_circuits = [*train_circuits, *val_circuits, *test_circuits]
model = NumpyModel.from_diagrams(
    all_circuits,
    use_jit=True,
)

In [None]:
def loss(y_hat, y, eps=1e-12):
    """Binary cross-entropy loss averaged over the batch.

    Args:
        y_hat: Predicted values, one row per sample.
        y: Target values with the same layout as `y_hat`.
        eps: Lower bound applied to `y_hat` before taking the logarithm.
            Without it an exact 0 in `y_hat` gives `log(0) = -inf`, which
            turns the loss into NaN (`0 * -inf`) or inf.

    Returns:
        `-sum(y * log(max(y_hat, eps))) / len(y)`.
    """
    safe_y_hat = np.maximum(y_hat, eps)
    return -np.sum(np.asarray(y) * np.log(safe_y_hat)) / len(y)


def acc(y_hat, y):
    """Fraction of samples whose rounded prediction matches the target.

    Entries are compared element-wise; the extra division by 2 compensates
    for counting both entries of each two-element label.
    """
    return np.sum(np.round(y_hat) == np.array(y)) / len(y) / 2  # half due to double-counting


eval_metrics = {"acc": acc}

In [None]:
# Tunable run settings; the trailing comments list other values to try.
BATCH_SIZE = 2  # 4, 8, 16, 32, 64, 128
EPOCHS = 2000
SEED = 0  # 1, 42, 100, 200

# SPSA-driven trainer; the 'A' hyperparameter is tied to the epoch count.
trainer = QuantumTrainer(
    model,
    loss_function=loss,
    epochs=EPOCHS,
    seed=SEED,
    optimizer=SPSAOptimizer,
    optim_hyperparams={
        'a': 0.1,
        'c': 0.06,
        'A': 0.01 * EPOCHS,
    },
    evaluate_functions=eval_metrics,
    evaluate_on_train=True,
    verbose='text',
)

In [None]:
# Training data is batched with BATCH_SIZE; validation data is left unshuffled.
train_dataset = Dataset(train_circuits, train_labels, batch_size=BATCH_SIZE)
val_dataset = Dataset(val_circuits, val_labels, shuffle=False)

# Evaluate every epoch, log every 100 epochs.
trainer.fit(
    train_dataset,
    val_dataset,
    evaluation_step=1,
    logging_step=100,
)

In [None]:
import matplotlib.pyplot as plt

# 2x2 grid: loss on the top row, accuracy on the bottom row;
# training on the left column, development (validation) on the right.
fig, ((ax_tl, ax_tr), (ax_bl, ax_br)) = plt.subplots(
    2, 2,
    sharex=True,
    sharey='row',
    figsize=(10, 6),
)

for axis, title in ((ax_tl, 'Training set'), (ax_tr, 'Development set')):
    axis.set_title(title)
for axis in (ax_bl, ax_br):
    axis.set_xlabel('Iterations')
ax_bl.set_ylabel('Accuracy')
ax_tl.set_ylabel('Loss')

# One colour per curve, drawn from the default colour cycle in this order:
# train loss, train accuracy, validation loss, validation accuracy.
colours = iter(plt.rcParams['axes.prop_cycle'].by_key()['color'])
for axis, series in (
    (ax_tl, trainer.train_epoch_costs),
    (ax_bl, trainer.train_results['acc']),
    (ax_tr, trainer.val_costs),
    (ax_br, trainer.val_results['acc']),
):
    axis.plot(series, color=next(colours))

# print test accuracy
test_acc = acc(model(test_circuits), test_labels)
print('Test accuracy:', test_acc)