# Imports

In [None]:
# GroundSLASH
from ground_slash.program import Program, Choice
from ground_slash.grounding import Grounder

# PyTorch
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

# PyTorch Geometric
import torch_geometric
from torch_geometric.data import HeteroData, Data, Batch

### Select compute device (CUDA if available, otherwise CPU)

In [None]:
# Pick the compute device: the default CUDA GPU when one is available, else the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(device)

# Program

In [None]:
# Candidate values for each digit NPP: the digit values 0-9.
digits = [*range(10)]
# Number of classifier outputs: one per candidate digit value.
n_out = len(digits)

# ASP program over two image constants. The `#npp` line (presumably a neural
# predicate declaration) receives the candidate list via f-string interpolation
# of `digits`; `addition/3` relates the two digits to their sum in both
# argument orders.
prog_str = fr'''
img(i1). img(i2).

#npp(digit(X), {digits}) :- img(X).

addition(A,B,N1+N2):- digit(A,N1), digit(B,N2), A<B.
addition(B,A,N) :- addition(A,B,N), A<B.
'''

In [None]:
from asn.asn import ASN

# Build the ASN object directly from the program text defined above.
asn = ASN.from_string(prog_str)

In [None]:
# Show the program stored on the ASN (print() already applies str()).
print(asn.prog)

# Dataset

In [None]:
from asn.data.datasets.mnist_addition import MNISTAddition
import torchvision.transforms as tf
from torchvision.datasets import MNIST

# Shared preprocessing for every MNIST-derived dataset: convert images to
# tensors and normalize them (0.1307 / 0.3081 are presumably the MNIST
# training-set mean / std). Defined once so train and test preprocessing
# cannot drift apart.
mnist_transform = tf.Compose([tf.ToTensor(), tf.Normalize((0.1307,), (0.3081,))])
# Directory where the MNIST files are stored / downloaded.
data_root = "../data/"

# MNIST addition dataset (n=2: presumably two digit images per sample)
mnist_add = MNISTAddition(
    n=2,
    root=data_root,
    train=True,
    transform=mnist_transform,
    download=True,
    digits=digits,
    seed=1234,
)
# original MNIST dataset: the training split is taken from the addition
# dataset's underlying MNIST object, the test split is loaded separately
mnist_train = mnist_add.mnist
mnist_test = MNIST(
    root=data_root,
    train=False,
    transform=mnist_transform,
    download=True,
)

print(len(mnist_train))
print(len(mnist_test))
print(len(mnist_add))

# NPP configuration

In [None]:
from asn.models.alexnet import AlexNet

# create NPP model for digits
# A single network instance; the NPP configuration cell attaches this same
# `model` to every NPP rule. `n_out` (= len(digits)) is presumably the number
# of output classes — confirm against asn.models.alexnet.AlexNet.
model = AlexNet(n_out)
# Move the model to the selected device (assumes AlexNet is an nn.Module, for
# which `.to` moves parameters in place and returns the module). Left as a bare
# expression, so the notebook shows the module as this cell's output.
model.to(device)

In [None]:
from asn.solver import NPPContext

# Attach the network (and an optimizer) to every NPP rule of the ASN.
# All rules share the same `model`, so only the first rule (index 0) gets an
# Adam optimizer; the others get None so the shared weights are not updated
# several times per step.
asn.configure_NPPs(
    {
        npp_rule: {
            "model": model,
            "optimizer": None if idx > 0 else optim.Adam(model.parameters(), lr=0.005),
        }
        for idx, npp_rule in enumerate(asn.rg.npp_edges)
    }
)

# Batching

In [None]:
import math

# Large batches for accuracy evaluation, smaller ones for training.
eval_batch_size, train_batch_size = 10000, 512

# Single-digit MNIST loaders, used only to measure classifier accuracy.
mnist_test_loader = DataLoader(mnist_test, batch_size=eval_batch_size)
mnist_train_loader = DataLoader(mnist_train, shuffle=True, batch_size=eval_batch_size)
# MNIST-addition loader, used for training the NPPs.
mnist_addition_loader = DataLoader(mnist_add, shuffle=True, batch_size=train_batch_size)

# Training & Evaluation

In [None]:
def eval_loader(model: nn.Module, loader: DataLoader, device: "torch.device | None" = None):
    """Measure the classification accuracy of ``model`` over ``loader``.

    The model is put in evaluation mode for the pass, so layers such as
    dropout or batch-norm (if present) behave deterministically and do not
    update running statistics. The caller's train/eval mode is restored
    afterwards, even if an exception is raised.

    Args:
        model: classifier whose last output dimension holds one score per
            class; the prediction is the argmax over that dimension.
        loader: iterable of ``(x, y)`` batches with integer labels ``y``.
        device: device each batch is moved to. Defaults to the device of the
            model's first parameter (CPU if the model has no parameters).

    Returns:
        A string ``"<n_correct>/<n_total>"``, a tab, then the accuracy in
        parentheses. The accuracy is ``nan`` when the loader yields no samples.
    """
    if device is None:
        # Evaluate wherever the model's weights already live.
        device = next((p.device for p in model.parameters()), torch.device("cpu"))

    was_training = model.training
    model.eval()

    n_correct = 0
    n_total = 0

    try:
        with torch.no_grad():
            for x, y in loader:
                x = x.to(device)
                y = y.to(device)

                y_pred = torch.argmax(model(x), dim=-1)
                n_correct += int((y_pred == y).sum().item())
                n_total += len(y)
    finally:
        # Restore the caller's mode (e.g. stay in training mode inside a training loop).
        model.train(was_training)

    # Guard the empty-loader case instead of raising ZeroDivisionError.
    accuracy = n_correct / n_total if n_total else float("nan")
    return f"{n_correct}/{n_total}\t({accuracy})"

In [None]:
from time import perf_counter
from copy import deepcopy
from asn.solver import SolvingContext

# number of epochs
n_epochs = 10

# accuracy of the untrained network on the MNIST test and train splits
print(f"0/{n_epochs}\t", "\t", eval_loader(model, mnist_test_loader), eval_loader(model, mnist_train_loader))

epoch_times = []

for e in range(n_epochs):

    # make sure the shared digit network is in training mode for this epoch
    model.train()

    ts = perf_counter()

    # running loss for epoch (kept on-device to avoid a host sync per batch)
    total_loss = torch.tensor(0.0, device=device)

    # for each batch
    for x, y in mnist_addition_loader:

        # NPP forward pass: the i-th element of `x` (presumably the batch of
        # i-th images) is fed to the i-th NPP rule
        npp_ctx_dict = asn.npp_forward(
            npp_data={
                npp_rule: (x_i.to(device),)
                for npp_rule, x_i in zip(asn.rg.npp_edges, x)
            },
        )

        # initialize solving context
        solving_ctx = SolvingContext(
            len(y),
            npp_ctx_dict,
        )

        # prepare data graph
        graph_block = asn.prepare_block(
            queries=mnist_add.to_queries(y),
            device=device,
        )

        # solve
        graph_block = asn.solve(graph_block)

        # update stable models
        solving_ctx.update_SMs(graph_block)

        # read the NPP objective from the solving context
        # (gradients are only computed by the backward pass below)
        loss = solving_ctx.npp_loss

        # add loss to running loss
        total_loss += loss.detach()

        # zero gradients
        asn.zero_grad()

        # backward pass; the objective is negated, i.e. `npp_loss` is maximized
        # (presumably a log-likelihood — confirm in asn.solver)
        (-loss).backward()

        # update NPPs
        asn.step()

    # CUDA kernels run asynchronously: wait for them so the epoch time covers all queued work
    if device.type == "cuda":
        torch.cuda.synchronize(device)
    epoch_time = perf_counter() - ts
    epoch_times.append(epoch_time)

    # evaluate
    print(f"{e+1}/{n_epochs} ({epoch_time})", total_loss.item(), "\t", eval_loader(model, mnist_test_loader), eval_loader(model, mnist_train_loader))

# average over the epochs actually run (avoids dividing by zero when n_epochs == 0)
if epoch_times:
    print(f"Average time per epoch: {sum(epoch_times)/len(epoch_times)}")