# Network Initialization

In [1]:
!pip install lark groq -q

[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/111.0 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m111.0/111.0 kB[0m [31m9.2 MB/s[0m eta [36m0:00:00[0m
[?25h[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/130.8 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m130.8/130.8 kB[0m [31m10.0 MB/s[0m eta [36m0:00:00[0m
[?25h

In [2]:
from google.colab import drive
from typing import Any, Callable, Literal
from datetime import datetime
from tqdm.auto import tqdm
from PIL import Image, ImageDraw, ImageFont
from sklearn.metrics import roc_auc_score
from groq import Groq
import os
import shutil
import io
import re
import json
import base64
import random
import inspect
import time
import lark
import torch
import pandas as pd
import matplotlib.pyplot as plt

ROOT_DIR = os.getcwd()
DRIVE_DATASETS_DIR = os.path.join(ROOT_DIR, 'drive', 'MyDrive', 'Thesis', 'visudo_pc', 'datasets')
DRIVE_LOGS_DIR = os.path.join(ROOT_DIR, 'drive', 'MyDrive', 'Thesis', 'visudo_pc', 'logs')
DATASETS_DIR = os.path.join(ROOT_DIR, 'data', 'datasets', 'visudo_pc')
GROQ_API_KEY = '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'

In [3]:
drive.mount('/content/drive')
if not os.path.exists(DATASETS_DIR):
    os.makedirs(DATASETS_DIR)
for file_name in os.listdir(DRIVE_DATASETS_DIR):
    if file_name.endswith('.pt'):
        file_path = os.path.join(DRIVE_DATASETS_DIR, file_name)
        shutil.copy(file_path, DATASETS_DIR)

Mounted at /content/drive


# LTN

## FOL Grammar

In [None]:
CONSTANT_TERMINAL = 'constant'
VARIABLE_TERMINAL = 'variable'
FUNCTION_TERMINAL = 'function'
PREDICATE_TERMINAL = 'predicate'
WRAPPER_TERMINAL = 'wrapper'
LOGICAL_NOT_TERMINAL = 'logical_not'
LOGICAL_AND_TERMINAL = 'logical_and'
LOGICAL_OR_TERMINAL = 'logical_or'
IMPLIES_TERMINAL = 'implies'
IFF_TERMINAL = 'iff'
EXISTS_TERMINAL = 'exists'
FORALL_TERMINAL = 'forall'

CONSTANT_SYMBOL = 'C'
VARIABLE_SYMBOL = 'x'
FUNCTION_SYMBOL = 'f'
PREDICATE_SYMBOL = 'P'
WRAPPER_SYMBOL = ''
LOGICAL_NOT_SYMBOL = '!'
LOGICAL_AND_SYMBOL = '&'
LOGICAL_OR_SYMBOL = '|'
IMPLIES_SYMBOL = 'implies'
IFF_SYMBOL = 'iff'
EXISTS_SYMBOL = 'exists'
FORALL_SYMBOL = 'forall'

FOL_GRAMMAR = f'''
//// Explanations ////

    // {CONSTANT_TERMINAL} identifiers always start with "{CONSTANT_SYMBOL}"
    // {VARIABLE_TERMINAL} identifiers always start with "{VARIABLE_SYMBOL}"
    // {FUNCTION_TERMINAL} identifiers always start with "{FUNCTION_SYMBOL}"
    // {PREDICATE_TERMINAL} identifiers always start with "{PREDICATE_SYMBOL}"
    // wrapper symbol is "{WRAPPER_SYMBOL}(" and ")"
    // negation symbol is "{LOGICAL_NOT_SYMBOL}"
    // conjunction symbol is "{LOGICAL_AND_SYMBOL}"
    // disjunction symbol is "{LOGICAL_OR_SYMBOL}"
    // implication symbol is "{IMPLIES_SYMBOL}"
    // equivalence symbol is "{IFF_SYMBOL}"
    // universal quantifier symbol is "{FORALL_SYMBOL}"
    // existential quantifier symbol is "{EXISTS_SYMBOL}"

//// Initialization ////

    // imports
     %import common.WS
     %ignore WS

    // entry point
    ?start: expression

//// Term-Level Terminal Definitions  ////

    // Tree Structure:
    // term
    // ├─atom
    // │ └─{CONSTANT_TERMINAL}, {VARIABLE_TERMINAL}
    // └─mapper
    //   └─{FUNCTION_TERMINAL}

    // Abstract Terminal (no precedence)
    ?term: {CONSTANT_TERMINAL} | {VARIABLE_TERMINAL} | {FUNCTION_TERMINAL}

    // Concrete Terminals (no precedence)
    {CONSTANT_TERMINAL}: /{CONSTANT_SYMBOL}[a-z0-9_]*/
    {VARIABLE_TERMINAL}: /{VARIABLE_SYMBOL}[a-z0-9_]*/
    {FUNCTION_TERMINAL}: /{FUNCTION_SYMBOL}[a-z0-9_]*/ "(" term ("," term)* ")"

//// Expression-Level Terminal Definitions ////

    // Tree Structure:
    // expression
    // ├─evaluator
    // │ └─{PREDICATE_TERMINAL}
    // ├─unary_connective
    // │ └─{WRAPPER_TERMINAL}, {LOGICAL_NOT_TERMINAL}
    // ├─binary_connective
    // │ └─{LOGICAL_AND_TERMINAL}, {LOGICAL_OR_TERMINAL}, {IMPLIES_TERMINAL}, {IFF_TERMINAL}
    // └─quantifier
    //   └─{EXISTS_TERMINAL}, {FORALL_TERMINAL}

    // Abstract Terminal (ascending precedence)
    ?expression: level_0
    ?level_0: level_1 | {EXISTS_TERMINAL} | {FORALL_TERMINAL}
    ?level_1: level_2 | {IFF_TERMINAL} | {IMPLIES_TERMINAL}
    ?level_2: level_3 | {LOGICAL_OR_TERMINAL}
    ?level_3: level_4 | {LOGICAL_AND_TERMINAL}
    ?level_4: level_5 | {LOGICAL_NOT_TERMINAL} | {WRAPPER_TERMINAL}
    ?level_5: predicate

    // Concrete Terminals (ascending precedence)
    {EXISTS_TERMINAL}: "{EXISTS_SYMBOL}" {VARIABLE_TERMINAL} ("," {VARIABLE_TERMINAL})* expression
    {FORALL_TERMINAL}: "{FORALL_SYMBOL}" {VARIABLE_TERMINAL} ("," {VARIABLE_TERMINAL})* expression
    {IFF_TERMINAL}: level_1 "{IFF_SYMBOL}" level_2
    {IMPLIES_TERMINAL}: level_1 "{IMPLIES_SYMBOL}" level_2
    {LOGICAL_OR_TERMINAL}: level_2 "{LOGICAL_OR_SYMBOL}" level_3
    {LOGICAL_AND_TERMINAL}: level_4 "{LOGICAL_AND_SYMBOL}" level_3
    {LOGICAL_NOT_TERMINAL}: "{LOGICAL_NOT_SYMBOL}" level_4
    {WRAPPER_TERMINAL}: "{WRAPPER_SYMBOL}(" expression ")"
    {PREDICATE_TERMINAL}: /{PREDICATE_SYMBOL}[a-z0-9_]*/ "(" term ("," term)* ")"
'''

## Groundings

In [None]:
class Grd():

    # Implicit Super Grounding
    class Base():
        def __init__(self, description:str, content:None|str|torch.Tensor|torch.nn.Module|Callable[..., torch.Tensor], hyper_arg_dict:dict[str, int|float]) -> None:
            self.description = description
            self.content = content
            self.hyper_arg_dict = hyper_arg_dict
        def __repr__(self) -> str:
            return self.__str__()
        def __str__(self) -> str:
            return self.description
        def __call__(self, *args:Any) -> Any:
            raise NotImplementedError()

    # Explicit Groundings
    class Empty(Base):
        def __init__(self) -> None:
            super().__init__('empty', None, {})
        def __call__(self) -> None:
            return self.content

    class Command(Base):
        def __init__(self, command:str) -> None:
            super().__init__('command', command, {})
        def __call__(self) -> str:
            return self.content

    class Value(Base):
        def __init__(self, tensor:torch.Tensor) -> None:
            shape = 'scalar' if tensor.ndim == 0 else f'{tensor.shape[0]}x0' if tensor.ndim == 1 else 'x'.join(str(dim) for dim in tensor.shape)
            trainability = 'trainable' if tensor.requires_grad else 'non-trainable'
            super().__init__(f'tensor[{shape}, {trainability}]', tensor, {})
        def __call__(self) -> torch.Tensor:
            return self.content

    class Network(Base):
        def __init__(self, network:torch.nn.Module) -> None:
            tot_params = sum(tensor.numel() for tensor in network.parameters())
            tot_trainable_params = sum(tensor.numel() for tensor in network.parameters() if tensor.requires_grad)
            super().__init__(f'network[{tot_params} params, {tot_trainable_params} trainable]', network, {})
        def __call__(self, *args:torch.Tensor) -> torch.Tensor:
            return self.content(*args)

    class Routine(Base):
        @classmethod
        def wrap(cls, description:str) -> Callable[[Callable[..., torch.Tensor]], 'Grd.Routine']:
            def decorator(routine: Callable[..., torch.Tensor]) -> Grd.Routine:
                return Grd.Routine(staticmethod(routine), description)
            return decorator
        def __init__(self, routine:Callable[..., torch.Tensor], description:str):
            signature = inspect.signature(routine).parameters
            assert all(arg_value.annotation is not inspect._empty for arg_value in signature.values()), "The routine must be type-hinted!"
            hyper_arg_dict = dict[str, int|float]()
            for arg_name, arg_value in signature.items():
                if arg_value.annotation in (int, float):
                    hyper_arg_dict[arg_name] = arg_value.default if arg_value.default is not inspect._empty else None
            if len(hyper_arg_dict) > 0:
                hyper_arg = f', {len(hyper_arg_dict)} hyper-arg' + ('s' if len(hyper_arg_dict) > 1 else '') + f', {sum(hyper_arg_value is None for hyper_arg_value in hyper_arg_dict.values())} empty'
            else:
                hyper_arg = ''
            super().__init__(f'{routine.__name__}[{description}{hyper_arg}]', routine, hyper_arg_dict)
        def __call__(self, *args:torch.Tensor) -> torch.Tensor:
            return self.content(*args, **self.hyper_arg_dict)

### Tests

In [None]:
a = 'hi'
b = Grd.Command(a)
print('command')
print('description', b.description)
print('content', b.content, a is b.content)
print('arg_dict', b.hyper_arg_dict)
print('call', b())
print()

a = torch.tensor([1, 2, 3])
b = Grd.Value(a)
print('value')
print('description', b.description)
print('content', b.content, a is b.content)
print('arg_dict', b.hyper_arg_dict)
print('call', b())
print()

a = torch.nn.Linear(3, 4)
b = Grd.Network(a)
c = torch.tensor([1.0, 2.0, 3.0])
print('network')
print('description', b.description)
print('content', b.content, a is b.content)
print('arg_dict', b.hyper_arg_dict)
print('call', b(c))
print()

def a(x:torch.Tensor) -> torch.Tensor:
    return x
b = Grd.Routine(a, 'x + 1')
c = torch.tensor([1.0, 2.0, 3.0])
print('routine')
print('description', b.description)
print('content', b.content, a is b.content)
print('arg_dict', b.hyper_arg_dict)
print('call', b(c))
print()

def a(reduction_list:list[int], x:torch.Tensor, y:int, z:float=6.5) -> torch.Tensor:
    return x[reduction_list] + y
b = Grd.Routine(a, 'x + 1')
c = torch.tensor([1.0, 2.0, 3.0])
print('routine')
print('description', b.description)
print('content', b.content, a is b.content)
print('arg_dict', b.hyper_arg_dict)
b.hyper_arg_dict['y'] = 1
print('call', b([1,2], c))
print()

command
description command
content hi True
arg_dict {}
call hi

value
description tensor[3x0, non-trainable]
content tensor([1, 2, 3]) True
arg_dict {}
call tensor([1, 2, 3])

network
description network[16 params, 16 trainable]
content Linear(in_features=3, out_features=4, bias=True) True
arg_dict {}
call tensor([-0.0422,  0.9244,  0.0816, -1.4042], grad_fn=<ViewBackward0>)

routine
description a[x + 1]
content <function a at 0x7cb6b7bda480> True
arg_dict {}
call tensor([1., 2., 3.])

routine
description a[x + 1, 2 hyper-args, 1 empty]
content <function a at 0x7cb6b7bdaf20> True
arg_dict {'y': None, 'z': 6.5}
call tensor([3., 4.])



## Blocks

In [None]:
class Blk():

    # Abstract Super Block
    class Base():
        PRINT_PRIORITY = 0
        TERMINAL:str|None = None
        def __init__(self, name:str, symbol:str, *children:'Blk.Base') -> None:
            self.name = name
            self.symbol = symbol
            self.children = children
            self.current_grd:Grd.Base = Grd.Empty()
            self.default_grd_dict = {k: v for k, v in self.__class__.__dict__.items() if isinstance(v, Grd.Routine)}
            self.global_blk_list = list['Blk.Base']()
            self.global_domain_list = list[str]()
            self.involved_domain_set = set[str]()
            self.arg_shape_dict = dict[str, list[str]]()
            self.inp_shape_list = list[str]()
            self.out_shape_list = list[str]()
        def __repr__(self) -> str:
            return self.__str__()
        def __str__(self) -> str:
            raise NotImplementedError()
        def __call__(self, **value_dict:torch.Tensor) -> torch.Tensor:
            raise NotImplementedError()
        def setup_blks(self, global_blk_list:list['Blk.Base']) -> None:
            for child in self.children:
                child.setup_blks(global_blk_list)
            global_blk_list.append(self)
            self.global_blk_list = global_blk_list
        def setup_domains(self, global_domain_list:list[str]) -> None:
            for child in self.children:
                child.setup_domains(global_domain_list)
            if isinstance(self, Blk.Quantifier):
                for variable in self.children[:-1]:
                    if variable.name not in global_domain_list:
                        global_domain_list.append(variable.name)
            global_domain_list.sort()
            self.global_domain_list = global_domain_list
        def setup_shapes(self) -> None:
            for child in self.children:
                child.setup_shapes()
            if isinstance(self, Blk.Atom):
                self.out_shape_list.append('B*')
                if self.name in self.global_domain_list:
                    self.involved_domain_set = {self.name}
                    self.out_shape_list.append(f'D{self.name}')
                self.out_shape_list.append(f'E{self.name}')
            elif isinstance(self, Blk.Mapper):
                for term in self.children:
                    self.involved_domain_set.update(term.involved_domain_set)
                    self.arg_shape_dict[term.name] = term.out_shape_list
                self.out_shape_list.append('B*')
                for domain in self.global_domain_list:
                    self.out_shape_list.append(f'D{domain}' if domain in self.involved_domain_set else '1')
                self.out_shape_list.append(f'E{self.name}')
            elif isinstance(self, Blk.BinaryConnective):
                left, right = self.children
                self.involved_domain_set.update(left.involved_domain_set)
                self.involved_domain_set.update(right.involved_domain_set)
                self.arg_shape_dict['left_tensor'] = ['B*'] + [f'D{domain}*' for domain in self.global_domain_list]
                self.arg_shape_dict['right_tensor'] = ['B*'] + [f'D{domain}*' for domain in self.global_domain_list]
                self.out_shape_list = ['B*'] + [f'D{domain}*' for domain in self.global_domain_list]
            elif isinstance(self, Blk.UnaryConnective):
                body = self.children[0]
                self.involved_domain_set.update(body.involved_domain_set)
                self.arg_shape_dict['tensor'] = ['B*'] + [f'D{domain}*' for domain in self.global_domain_list]
                self.out_shape_list = ['B*'] + [f'D{domain}*' for domain in self.global_domain_list]
            elif isinstance(self, Blk.Quantifier):
                body = self.children[-1]
                self.involved_domain_set.update(self.global_domain_list)
                for variable in self.children[:-1]:
                    self.involved_domain_set.remove(variable.name)
                self.arg_shape_dict['tensor'] = ['B*'] + [f'D{domain}*' for domain in self.global_domain_list]
                self.out_shape_list = ['B*'] + [f'D{domain}**' for domain in self.global_domain_list]
            elif isinstance(self, Blk.Evaluator):
                for term in self.children:
                    self.involved_domain_set.update(term.involved_domain_set)
                    self.arg_shape_dict[term.name] = term.out_shape_list
                self.out_shape_list.append('B*')
                for domain in self.global_domain_list:
                    self.out_shape_list.append(f'D{domain}' if domain in self.involved_domain_set else '1')
        def ground(self, **context_dict:Grd.Command|Grd.Value|Grd.Network|Grd.Routine) -> None:
            if self.name in context_dict:
                temporary_ground = context_dict[self.name]
                if isinstance(temporary_ground, Grd.Command):
                    self.current_grd = self.default_grd_dict[temporary_ground()]
                else:
                    self.current_grd = temporary_ground
            for child in self.children:
                child.ground(**context_dict)
        def control(self, **hyper_arg_dict:dict[str, int|float]) -> None:
            if self.name in hyper_arg_dict:
                for arg_name, arg_value in hyper_arg_dict[self.name].items():
                    if arg_name in self.current_grd.hyper_arg_dict:
                        self.current_grd.hyper_arg_dict[arg_name] = arg_value
            for child in self.children:
                child.control(**hyper_arg_dict)

    # Abstract High-Level Blocks
    class Term(Base):
        def __init__(self, name:str, symbol:str, *children:'Blk.Base') -> None:
            super().__init__(name, symbol, *children)

    class Expression(Base):
        def __init__(self, name:str, symbol:str, *children:'Blk.Base') -> None:
            super().__init__(name, symbol, *children)

    # Abstract Medium-Level Blocks
    class Atom(Term):
        PRINT_PRIORITY = 1
        def __init__(self, name:str, symbol:str) -> None:
            super().__init__(name, symbol)
            self.current_grd:Grd.Empty|Grd.Value = Grd.Empty()
        def __str__(self) -> str:
            return self.name
        def __call__(self, **value_dict:torch.Tensor) -> torch.Tensor:
            if isinstance(self.current_grd, Grd.Empty):
                return value_dict[self.name]
            return self.current_grd()

    class Mapper(Term):
        PRINT_PRIORITY = 2
        def __init__(self, name:str, symbol:str, *terms:'Blk.Term') -> None:
            super().__init__(name, symbol, *terms)
            self.current_grd:Grd.Empty|Grd.Network|Grd.Routine = Grd.Empty()
        def __str__(self) -> str:
            return self.name + '(' + ', '.join(str(child) for child in self.children) + ')'
        def __call__(self, **value_dict:torch.Tensor) -> torch.Tensor:
            return self.current_grd(*[term(**value_dict) for term in self.children])

    class Evaluator(Expression):
        PRINT_PRIORITY = 3
        def __init__(self, name:str, symbol:str, *terms:'Blk.Term') -> None:
            super().__init__(name, symbol, *terms)
            self.current_grd:Grd.Empty|Grd.Network|Grd.Routine = Grd.Empty()
        def __str__(self) -> str:
            return self.name + '(' + ', '.join(str(child) for child in self.children) + ')'
        def __call__(self, **value_dict:torch.Tensor) -> torch.Tensor:
            return self.current_grd(*[term(**value_dict) for term in self.children])

    class UnaryConnective(Expression):
        PRINT_PRIORITY = 4
        def __init__(self, name:str, symbol:str, default_grd:Grd.Routine, body:'Blk.Expression') -> None:
            super().__init__(name, symbol, body)
            self.current_grd = default_grd
        def __str__(self) -> str:
            return self.symbol + '(' + str(self.children[0]) + ')'
        def __call__(self, **value_dict:torch.Tensor) -> torch.Tensor:
            body = self.children[0]
            return self.current_grd(body(**value_dict))

    class BinaryConnective(Expression):
        PRINT_PRIORITY = 5
        def __init__(self, name:str, symbol:str, default_grd:Grd.Routine, left:'Blk.Expression', right:'Blk.Expression') -> None:
            super().__init__(name, symbol, left, right)
            self.current_grd = default_grd
        def __str__(self) -> str:
            return str(self.children[0]) + ' ' + self.symbol + ' ' + str(self.children[1])
        def __call__(self, **value_dict:torch.Tensor) -> torch.Tensor:
            left, right = self.children
            return self.current_grd(left(**value_dict), right(**value_dict))

    class Quantifier(Expression):
        PRINT_PRIORITY = 6
        def __init__(self, name:str, symbol:str, default_grd:Grd.Routine, body:'Blk.Expression', *variables:'Blk.Variable') -> None:
            super().__init__(name, symbol, *variables, body)
            self.current_grd = default_grd
        def __str__(self) -> str:
            return self.symbol + '[' + ', '.join(str(child) for child in self.children[:-1]) + '](' + str(self.children[-1]) + ')'
        def __call__(self, **value_dict:torch.Tensor) -> torch.Tensor:
            variable_list = self.children[:-1]
            body = self.children[-1]
            reduction_dim_list = [self.global_domain_list.index(variable.name) + 1 for variable in variable_list]
            return self.current_grd(reduction_dim_list, body(**value_dict))

    # Concrete Blocks
    class Constant(Atom):
        TERMINAL = CONSTANT_TERMINAL
        def __init__(self, in_list:list[lark.Token]) -> None:
            super().__init__(in_list[0].value, CONSTANT_SYMBOL)

    class Variable(Atom):
        TERMINAL = VARIABLE_TERMINAL
        def __init__(self, in_list:list[lark.Token]) -> None:
            super().__init__(in_list[0].value, VARIABLE_SYMBOL)

    class Function(Mapper):
        TERMINAL = FUNCTION_TERMINAL
        def __init__(self, in_list:list['lark.Token|Blk.Term']) -> None:
            super().__init__(in_list[0].value, FUNCTION_SYMBOL, *in_list[1:])

    class Predicate(Evaluator):
        TERMINAL = PREDICATE_TERMINAL
        def __init__(self, in_list:list['lark.Token|Blk.Term']) -> None:
            super().__init__(in_list[0].value, PREDICATE_SYMBOL, *in_list[1:])

    class Wrapper(UnaryConnective):
        TERMINAL = WRAPPER_TERMINAL
        def __init__(self, in_list:list['Blk.Expression']) -> None:
            super().__init__(WRAPPER_TERMINAL, WRAPPER_SYMBOL, self.identity, in_list[0])
        @Grd.Routine.wrap('x')
        def identity(tensor:torch.Tensor) -> torch.Tensor:
            return tensor

    class LogicalNot(UnaryConnective):
        TERMINAL = LOGICAL_NOT_TERMINAL
        def __init__(self, in_list:list['Blk.Expression']) -> None:
            super().__init__(LOGICAL_NOT_TERMINAL, LOGICAL_NOT_SYMBOL, self.complementation, in_list[0])
        @Grd.Routine.wrap('1 - x')
        def complementation(tensor:torch.Tensor) -> torch.Tensor:
            return 1.0 - tensor

    class LogicalAnd(BinaryConnective):
        TERMINAL = LOGICAL_AND_TERMINAL
        def __init__(self, in_list:list['Blk.Expression']) -> None:
            super().__init__(LOGICAL_AND_TERMINAL, LOGICAL_AND_SYMBOL, self.lukasiewicz, in_list[0], in_list[1])
        @Grd.Routine.wrap('max(0, x1 + x2 - 1)')
        def lukasiewicz(left_tensor:torch.Tensor, right_tensor:torch.Tensor) -> torch.Tensor:
            return torch.maximum(torch.tensor(0.0, device=left_tensor.device), left_tensor + right_tensor - 1.0)
        @Grd.Routine.wrap('min(x1, x2)')
        def godel(left_tensor:torch.Tensor, right_tensor:torch.Tensor) -> torch.Tensor:
            return torch.minimum(left_tensor, right_tensor)
        @Grd.Routine.wrap('x1 * x2')
        def goguen(left_tensor:torch.Tensor, right_tensor:torch.Tensor) -> torch.Tensor:
            return left_tensor * right_tensor

    class LogicalOr(BinaryConnective):
        TERMINAL = LOGICAL_OR_TERMINAL
        def __init__(self, in_list:list[Any]) -> None:
            super().__init__(LOGICAL_OR_TERMINAL, LOGICAL_OR_SYMBOL, self.lukasiewicz, in_list[0], in_list[1])
        @Grd.Routine.wrap('min(1, x1 + x2)')
        def lukasiewicz(left_tensor:torch.Tensor, right_tensor:torch.Tensor) -> torch.Tensor:
            return torch.minimum(torch.tensor(1.0, device=left_tensor.device), left_tensor + right_tensor)
        @Grd.Routine.wrap('max(x1, x2)')
        def godel(left_tensor:torch.Tensor, right_tensor:torch.Tensor) -> torch.Tensor:
            return torch.maximum(left_tensor, right_tensor)
        @Grd.Routine.wrap('1 - (1 - x1) * (1 - x2)')
        def goguen(left_tensor:torch.Tensor, right_tensor:torch.Tensor) -> torch.Tensor:
            return 1.0 - (1.0 - left_tensor) * (1.0 - right_tensor)

    class Implies(BinaryConnective):
        TERMINAL = IMPLIES_TERMINAL
        def __init__(self, in_list:list['Blk.Expression']) -> None:
            super().__init__(IMPLIES_TERMINAL, IMPLIES_SYMBOL, self.lukasiewicz, in_list[0], in_list[1])
        @Grd.Routine.wrap('min(1, 1 - x1 + x2)')
        def lukasiewicz(left_tensor:torch.Tensor, right_tensor:torch.Tensor) -> torch.Tensor:
            return torch.minimum(torch.tensor(1.0, device=left_tensor.device), 1 - left_tensor + right_tensor)
        @Grd.Routine.wrap('1 if x1 <= x2 else x2')
        def godel(left_tensor:torch.Tensor, right_tensor:torch.Tensor) -> torch.Tensor:
            return torch.where(left_tensor <= right_tensor, torch.tensor(1.0, device=left_tensor.device), right_tensor)
        @Grd.Routine.wrap('1 if x1 <=x2 else x2 / x1')
        def goguen(left_tensor:torch.Tensor, right_tensor:torch.Tensor) -> torch.Tensor:
            return torch.where(left_tensor <= right_tensor, torch.tensor(1.0, device=left_tensor.device), right_tensor / torch.clamp(left_tensor, min=1e-6))
        @Grd.Routine.wrap('max(1 - x1, x2)')
        def kleene_dienes(left_tensor:torch.Tensor, right_tensor:torch.Tensor) -> torch.Tensor:
            return torch.maximum(1 - left_tensor, right_tensor)
        @Grd.Routine.wrap('1 - x1 + x1 * x2')
        def reichenbach(left_tensor:torch.Tensor, right_tensor:torch.Tensor) -> torch.Tensor:
            return 1.0 - left_tensor + left_tensor * right_tensor

    class Iff(BinaryConnective):
        TERMINAL = IFF_TERMINAL
        def __init__(self, in_list:list['Blk.Expression']) -> None:
            super().__init__(IFF_TERMINAL, IFF_SYMBOL, self.linear_similarity, in_list[0], in_list[1])
        @Grd.Routine.wrap('1 - |x1 - x2|')
        def linear_similarity(left_tensor:torch.Tensor, right_tensor:torch.Tensor) -> torch.Tensor:
            return 1.0 - torch.abs(left_tensor - right_tensor)

    class ForAll(Quantifier):
        TERMINAL = FORALL_TERMINAL
        def __init__(self, in_list:list['Blk.Variable|Blk.Expression']) -> None:
            super().__init__(FORALL_TERMINAL, FORALL_SYMBOL, self.lukasiewicz, in_list[-1], *in_list[:-1])
        @Grd.Routine.wrap('max(0, sum(xi) - n + 1)')
        def lukasiewicz(reduction_dim_list:list[int], tensor:torch.Tensor) -> torch.Tensor:
            sum_tensor = torch.sum(tensor, dim=reduction_dim_list, keepdim=True)
            n = torch.prod(torch.tensor([tensor.size(dim) for dim in reduction_dim_list], device=tensor.device))
            return torch.maximum(torch.tensor(0.0, device=tensor.device), sum_tensor - n + 1)
        @Grd.Routine.wrap('min(xi)')
        def godel(reduction_dim_list:list[int], tensor:torch.Tensor) -> torch.Tensor:
            return torch.amin(tensor, dim=reduction_dim_list, keepdim=True)
        @Grd.Routine.wrap('prod(xi)')
        def goguen(reduction_dim_list:list[int], tensor:torch.Tensor) -> torch.Tensor:
            prod_tensor = tensor
            for dim in sorted(reduction_dim_list, reverse=True):
                prod_tensor = torch.prod(prod_tensor, dim=dim, keepdim=True)
            return prod_tensor
        @Grd.Routine.wrap('1 - mean((1 - xi) ** p) ** (1 / p)')
        def power_mean(reduction_dim_list:list[int], tensor:torch.Tensor, p:int=2) -> torch.Tensor:
            return 1.0 - torch.mean((1.0 - tensor) ** p, dim=reduction_dim_list, keepdim=True) ** (1 / p)

    class Exists(Quantifier):
        TERMINAL = EXISTS_TERMINAL
        def __init__(self, in_list:list['Blk.Variable|Blk.Expression']) -> None:
            super().__init__(EXISTS_TERMINAL, EXISTS_SYMBOL, self.lukasiewicz, in_list[-1], *in_list[:-1])
        @Grd.Routine.wrap('min(1, sum(xi))')
        def lukasiewicz(reduction_dim_list:list[int], tensor:torch.Tensor) -> torch.Tensor:
            sum_tensor = torch.sum(tensor, dim=reduction_dim_list, keepdim=True)
            return torch.minimum(torch.tensor(1.0, device=tensor.device), sum_tensor)
        @Grd.Routine.wrap('max(xi)')
        def godel(reduction_dim_list:list[int], tensor:torch.Tensor) -> torch.Tensor:
            return torch.amax(tensor, dim=reduction_dim_list, keepdim=True)
        @Grd.Routine.wrap('1 - prod(1 - xi)')
        def goguen(reduction_dim_list:list[int], tensor:torch.Tensor) -> torch.Tensor:
            prod_tensor = 1.0 - tensor
            for dim in sorted(reduction_dim_list, reverse=True):
                prod_tensor = torch.prod(prod_tensor, dim=dim, keepdim=True)
            return 1.0 - prod_tensor
        @Grd.Routine.wrap('mean(xi ** p) ** (1 / p)')
        def power_mean(reduction_dim_list:list[int], tensor:torch.Tensor, p:int=1) -> torch.Tensor:
            return torch.mean(tensor ** p, dim=reduction_dim_list, keepdim=True) ** (1 / p)

## Builder

In [None]:
class LTNBuilder():

    # Explicit Formula
    class Formula():
        def __init__(self, fol_rule:str, fol_tree:lark.Tree, ltn_expression:'Blk.Expression') -> None:
            self._fol_rule = fol_rule
            self._fol_tree = fol_tree
            self._ltn_expression = ltn_expression
        def __call__(self, **value_dict:torch.Tensor) -> torch.Tensor:
            output = self._ltn_expression(**value_dict).squeeze()
            if output.ndim == 0:
                output = output.unsqueeze(0)
            return output
        def get_parameters(self) -> list[torch.nn.Parameter]:
            parameter_list = list[torch.nn.Parameter]()
            for blk in self._ltn_expression.global_blk_list:
                if isinstance(blk.current_grd, Grd.Value):
                    parameter_list.append(blk.current_grd.content)
                elif isinstance(blk.current_grd, Grd.Network):
                    parameter_list += list(blk.current_grd.content.parameters())
            return parameter_list
        def set_config(self, mode:Literal['device', 'status'], param:str) -> None:
            if mode == 'device':
                for blk in self._ltn_expression.global_blk_list:
                    if isinstance(blk.current_grd, Grd.Value):
                        blk.current_grd.content = blk.current_grd.content.to(param)
                    elif isinstance(blk.current_grd, Grd.Network):
                        blk.current_grd.content = blk.current_grd.content.to(param)
            elif mode == 'status':
                if param == 'train':
                    for blk in self._ltn_expression.global_blk_list:
                        if isinstance(blk.current_grd, Grd.Value):
                            blk.current_grd.content.requires_grad = True
                        elif isinstance(blk.current_grd, Grd.Network):
                            blk.current_grd.content.train()
                elif param == 'eval':
                    for blk in self._ltn_expression.global_blk_list:
                        if isinstance(blk.current_grd, Grd.Value):
                            blk.current_grd.content.requires_grad = False
                        elif isinstance(blk.current_grd, Grd.Network):
                            blk.current_grd.content.eval()
        def ground(self, **raw_context_dict:str|torch.Tensor|torch.nn.Module|Callable[..., torch.Tensor]) -> None:
            context_dict = dict[str, Grd.Command|Grd.Value|Grd.Network|Grd.Routine]()
            for blk_name in raw_context_dict:
                if isinstance(raw_context_dict[blk_name], str):
                    context_dict[blk_name] = Grd.Command(raw_context_dict[blk_name])
                elif isinstance(raw_context_dict[blk_name], torch.Tensor):
                    context_dict[blk_name] = Grd.Value(raw_context_dict[blk_name])
                elif isinstance(raw_context_dict[blk_name], torch.nn.Module):
                    context_dict[blk_name] = Grd.Network(raw_context_dict[blk_name])
                elif isinstance(raw_context_dict[blk_name], Callable):
                    context_dict[blk_name] = Grd.Routine(raw_context_dict[blk_name], 'n/a')
            self._ltn_expression.ground(**context_dict)
        def control(self, **hyper_arg_dict:dict[str, int|float]) -> None:
            self._ltn_expression.control(**hyper_arg_dict)
        def get(self, mode:Literal['raw', 'parsed', 'tree', 'info']) -> str:
            output = None
            if mode == 'raw':
                output = self._fol_rule
            elif mode == 'parsed':
                output = self._ltn_expression
            elif mode == 'tree':
                output = self._fol_tree.pretty()[:-1]
            elif mode == 'info':
                ungrounded_blk_dict = dict[tuple[int, str], str]()
                grounded_blk_dict = dict[tuple[int, str], str]()
                for blk in self._ltn_expression.global_blk_list:
                    if isinstance(blk.current_grd, Grd.Empty):
                        ungrounded_blk_dict[blk.PRINT_PRIORITY, blk.name] = f'  {blk.name} -> {blk.TERMINAL}'
                    else:
                        grounded_blk_dict[blk.PRINT_PRIORITY, blk.name] = f'  {blk.name} -> {blk.current_grd}'
                        if len(blk.default_grd_dict) > 0:
                            grounded_blk_dict[blk.PRINT_PRIORITY, blk.name] += '*, ' + ', '.join(str(grd) for grd in blk.default_grd_dict.values() if grd != blk.current_grd)
                ungrounded_blk_dict = dict(sorted(ungrounded_blk_dict.items(), key=lambda item: item[0]))
                grounded_blk_dict = dict(sorted(grounded_blk_dict.items(), key=lambda item: item[0]))
                output = ''
                if len(ungrounded_blk_dict) > 0:
                    output += '\nUngrounded Symbols:\n' + '\n'.join(ungrounded_blk_dict.values())
                if len(grounded_blk_dict) > 0:
                    output += '\nGrounded Symbols:\n' + '\n'.join(grounded_blk_dict.values())
                output = output[1:] if output.startswith('\n') else output
                output += (
                    '\nFootnotes:'
                    '\n  * -> current grounding'
                )
            elif mode == 'shapes':
                blk_dict = dict[tuple[int, str], str]()
                for blk in self._ltn_expression.global_blk_list:
                    shapes = list[str]()
                    if len(blk.arg_shape_dict) > 0:
                        argument_shape_list = list[str]()
                        for arg in blk.arg_shape_dict:
                            argument_shape_list.append(f'{arg}[' + ', '.join(blk.arg_shape_dict[arg]) + ']')
                        shapes.append('  inp shape -> ' + ', '.join(argument_shape_list))
                    if len(blk.inp_shape_list) > 0:
                        shapes.append('  inp shape -> ' + ', '.join(blk.inp_shape_list))
                    if len(blk.out_shape_list) > 0:
                        shapes.append('  out shape -> ' + ', '.join(blk.out_shape_list))
                    blk_dict[blk.PRINT_PRIORITY, blk.name] = '\n'.join(shapes)
                blk_dict = dict(sorted(blk_dict.items(), key=lambda item: item[0]))
                output = '\n'.join(f'{blk_name}:\n{blk_shapes}' for (_, blk_name), blk_shapes in blk_dict.items())
                output += (
                    '\nFootnotes:'
                    '\n  B -> batch size'
                    '\n  Dx -> domain size of `x`'
                    '\n  Ex -> extra dimensions of `x`'
                    '\n  * -> the dimension can also be equal to `1` but not missing'
                    '\n  ** -> the dimension is equal to `1` if its corresponding variable is involved in the quantifier'
                )
            return output

    def __init__(self, fol_grammar=FOL_GRAMMAR) -> None:
        self._fol_parser = lark.Lark(fol_grammar)
        self._fol_transformer = lark.Transformer()
        for Block in Blk.__dict__.values():
            if isinstance(Block, type) and issubclass(Block, Blk.Base) and Block.TERMINAL is not None:
                setattr(self._fol_transformer, Block.TERMINAL, Block)

    def make_formula(self, fol_rule:str) -> Formula:
        fol_tree = self._fol_parser.parse(fol_rule)
        ltn_expression:Blk.Expression = self._fol_transformer.transform(fol_tree)
        ltn_expression.setup_blks(list['Blk.Base']())
        ltn_expression.setup_domains(list[str]())
        ltn_expression.setup_shapes()
        ltn_formula = LTNBuilder.Formula(fol_rule, fol_tree, ltn_expression)
        return ltn_formula

### Tests

In [None]:
fol_rule = 'forall x P(x)'
ltn_formula = LTNBuilder().make_formula(fol_rule)

def foo(x:torch.Tensor, y:int) -> torch.Tensor:
    return x + y
ltn_formula.ground(
    P=foo,
    forall='power_mean'
)
ltn_formula.control(
    P={'y': 1}
)

print('Raw')
print(ltn_formula.get('raw'))
print()

print('Parsed')
print(ltn_formula.get('parsed'))
print()

print('Tree')
print(ltn_formula.get('tree'))
print()

print('Info')
print(ltn_formula.get('info'))
print()

print('Shapes')
print(ltn_formula.get('shapes'))
print()

y = ltn_formula(x=torch.tensor([[1.0], [2.0]]))
print(y)

ltn_formula.control(
    P={'y': 5}
)
y = ltn_formula(x=torch.tensor([[1.0], [2.0]]))
print(y)

Raw
forall x P(x)

Parsed
forall[x](P(x))

Tree
forall
  variable	x
  predicate
    P
    variable	x

Info
Ungrounded Symbols:
  x -> variable
Grounded Symbols:
  P -> foo[n/a, 1 hyper-arg, 1 empty]
  forall -> power_mean[1 - mean((1 - xi) ** p) ** (1 / p), 1 hyper-arg, 0 empty]*, lukasiewicz[max(0, sum(xi) - n + 1)], godel[min(xi)], goguen[prod(xi)]
Footnotes:
  * -> current grounding

Shapes
x:
  out shape -> B*, Dx, Ex
P:
  inp shape -> x[B*, Dx, Ex]
  out shape -> B*, Dx
forall:
  inp shape -> tensor[B*, Dx*]
  out shape -> B*, Dx**
Footnotes:
  B -> batch size
  Dx -> domain size of `x`
  Ex -> extra dimensions of `x`
  * -> the dimension can also be equal to `1` but not missing
  ** -> the dimension is equal to `1` if its corresponding variable is involved in the quantifier

tensor([nan, nan])
tensor([nan, nan])


# Platform

## Groundings Tools

In [None]:
class VisualEncoder(torch.nn.Module):

    def __init__(self,
                 device:str,
                 symbol_size,
                 cnn_hidden_dims:tuple[int, ...],
                 embed_dim:int,
                 drop_prob:float) -> None:
        super().__init__()
        self.device = device
        self.symbol_size = symbol_size
        self.cnn_hidden_dims = cnn_hidden_dims
        self.embed_dim = embed_dim
        self.drop_prob = drop_prob
        cnn_list = torch.nn.ModuleList()
        last_cnn_hidden_dim = 1
        pooled_side = symbol_size
        for cnn_hidden_dim in cnn_hidden_dims:
            cnn_list.append(torch.nn.Conv2d(in_channels=last_cnn_hidden_dim, out_channels=cnn_hidden_dim, kernel_size=3, stride=1, padding=1))
            cnn_list.append(torch.nn.ReLU())
            cnn_list.append(torch.nn.BatchNorm2d(num_features=cnn_hidden_dim))
            cnn_list.append(torch.nn.MaxPool2d(kernel_size=2, stride=2))
            last_cnn_hidden_dim = cnn_hidden_dim
            pooled_side = pooled_side // 2
        self.cnn = torch.nn.Sequential(*cnn_list)
        self.cnn_projection = torch.nn.Sequential(
            torch.nn.Linear(in_features=last_cnn_hidden_dim * pooled_side * pooled_side, out_features=embed_dim - 2),
            torch.nn.Dropout(p=drop_prob)
        )
        self.to(device)

    def forward(self, symbols:torch.Tensor) -> torch.Tensor:
        '''
        inputs:
            symbols: Tensor[batch_size, dimension, dimension, symbol_size, symbol_size][cpu][torch.uint8]
        outputs:
            embeds: Tensor[batch_size, total_symbols, embed_dim][device][torch.float32]
        '''
        batch_size, dimension, _, symbol_size, _ = symbols.shape
        total_symbols = dimension ** 2
        symbols = symbols.reshape(batch_size, total_symbols, symbol_size, symbol_size) # Tensor[batch_size, total_symbols, symbol_size, symbol_size][device][torch.uint8]
        indices = torch.cartesian_prod(torch.arange(dimension), torch.arange(dimension)).expand(batch_size, -1, 2).to(torch.uint8) # Tensor[batch_size, total_symbols, 2][cpu][torch.uint8]
        symbols = symbols.to(self.device).to(torch.float32) / 255 # Tensor[batch_size, total_symbols, symbol_size, symbol_size][device][torch.float32]
        indices = indices.to(self.device).to(torch.float32) # Tensor[batch_size, total_symbols, 2][device][torch.float32]
        embeds = symbols.view(-1, 1, self.symbol_size, self.symbol_size) # Tensor[batch_size*total_symbols, 1, symbol_size, symbol_size][device][torch.float32]
        embeds = self.cnn(embeds) # Tensor[batch_size*total_symbols, last_cnn_hidden_dim, pooled_side, pooled_side][device][torch.float32]
        embeds = embeds.reshape(batch_size, total_symbols, -1) # Tensor[batch_size, total_symbols, last_cnn_hidden_dim*pooled_side*pooled_side][device][torch.float32]
        embeds = self.cnn_projection(embeds) # Tensor[batch_size, total_symbols, embed_dim - 2][device][torch.float32]
        embeds = torch.cat((embeds, indices), dim=-1) # Tensor[batch_size, total_symbols, embed_dim][device][torch.float32]
        return embeds

In [None]:
def binary_similarity(left_tensor:torch.Tensor, right_tensor:torch.Tensor) -> torch.Tensor:
    '''
    formula:
        y = (x1 == x2)
    inputs:
        left_tensor: Tensor[B, Dl, E][device][torch.float]
        right_tensor: Tensor[B, Dr, E][device][torch.float]
    outputs:
        output: Tensor[B, Dl, Dr][device][torch.float]
    footnotes:
        B: batch size, i.e, `batch_size`
        Dl: domain size of `left_tensor`
        Dr: domain size of `right_tensor`
        E: extra dimension of the tensors
    '''
    left_tensor = left_tensor.unsqueeze(2) # Tensor[B, Dl, 1, E][device][torch.float32]
    right_tensor = right_tensor.unsqueeze(1) # Tensor[B, 1, Dr, E][device][torch.float32]
    similarity = (left_tensor == right_tensor).all(dim=-1) # Tensor[B, Dl, Dr][device][torch.bool]
    return similarity.to(torch.float32)

def jaccard_similarity(left_tensor:torch.Tensor, right_tensor:torch.Tensor, eps:float=1e-6, exp:int=2) -> torch.Tensor:
    '''
    formula:
        y = ||x1|| ** 2
    inputs:
        left_tensor: Tensor[B, Dl, E][device][torch.float]
        right_tensor: Tensor[B, Dr, E][device][torch.float]
        exp: exponentiation power
        eps: stability parameter
    outputs:
        output: Tensor[B, Dl, Dr][device][torch.float]
    footnotes:
        B: batch size, i.e, `batch_size`
        Dl: domain size of `left_tensor`
        Dr: domain size of `right_tensor`
        E: extra dimension of the tensors
    '''
    left_tensor = left_tensor.unsqueeze(2) # Tensor[B, Dl, 1, E][device][torch.float32]
    right_tensor = right_tensor.unsqueeze(1) # Tensor[B, 1, Dr, E][device][torch.float32]
    left_norm = (left_tensor * left_tensor).sum(dim=-1) # Tensor[B, Dl, 1][device][torch.float32]
    right_norm = (right_tensor * right_tensor).sum(dim=-1) # Tensor[B, 1, Dr][device][torch.float32]
    dot_product = (left_tensor * right_tensor).sum(dim=-1) # Tensor[B, Dl, Dr][device][torch.float32]
    similarity = 2 * dot_product / (left_norm + right_norm + eps) # Tensor[B, Dl, Dr][device][torch.float32]
    similarity = torch.where((left_norm < eps) & (right_norm < eps), torch.ones_like(similarity), similarity) # Tensor[B, Dl, Dr][device][torch.float32]
    return similarity ** exp

def exponential_similarity(left_tensor:torch.Tensor, right_tensor:torch.Tensor, k:int=2) -> torch.Tensor:
    '''
    inputs:
        left_tensor: Tensor[B, Dl, E][device][torch.float]
        right_tensor: Tensor[B, Dr, E][device][torch.float]
        k: norm order
    outputs:
        output: Tensor[B, Dl, Dr][device][torch.float]
    footnotes:
        B: batch size, i.e, `batch_size`
        Dl: domain size of `left_tensor`
        Dr: domain size of `right_tensor`
        E: extra dimension of the tensors
    '''
    left_tensor = left_tensor.unsqueeze(2) # Tensor[B, Dl, 1, E][device][torch.float32]
    right_tensor = right_tensor.unsqueeze(1) # Tensor[B, 1, Dr, E][device][torch.float32]
    k_norm = torch.norm(left_tensor - right_tensor, p=k, dim=-1) # Tensor[B, Dl, Dr][device][torch.float32]
    similarity = torch.exp(-torch.relu(k_norm)) # Tensor[B, Dl, Dr][device][torch.float32]
    return similarity.to(torch.float32)

def row_similarity(x1:torch.Tensor, x2:torch.Tensor) -> torch.Tensor:
    _, _, E = x1.shape
    x1 = x1[:, :, (E - 2,)]
    x2 = x2[:, :, (E - 2,)]
    return binary_similarity(x1, x2)

def col_similarity(x1:torch.Tensor, x2:torch.Tensor) -> torch.Tensor:
    _, _, E = x1.shape
    x1 = x1[:, :, (E - 1,)]
    x2 = x2[:, :, (E - 1,)]
    return binary_similarity(x1, x2)

def block_similarity(x1:torch.Tensor, x2:torch.Tensor) -> torch.Tensor:
    _, D, E = x1.shape
    total_samples_per_side = int(D ** 0.25)
    x1 = x1[:, :, E - 2:] // total_samples_per_side
    x2 = x2[:, :, E - 2:] // total_samples_per_side
    return binary_similarity(x1, x2)

def loc_similarity(x1:torch.Tensor, x2:torch.Tensor) -> torch.Tensor:
    _, _, E = x1.shape
    x1 = x1[:, :, E - 2:]
    x2 = x2[:, :, E - 2:]
    return binary_similarity(x1, x2)

def digit_similarity(x1:torch.Tensor, x2:torch.Tensor) -> torch.Tensor:
    _, _, E = x1.shape
    x1 = x1[:, :, (0,)]
    x2 = x2[:, :, (0,)]
    return binary_similarity(x1, x2)

def vector_similarity(x1:torch.Tensor, x2:torch.Tensor) -> torch.Tensor:
    _, _, E = x1.shape
    x1 = x1[:, :, :E - 2]
    x2 = x2[:, :, :E - 2]
    return jaccard_similarity(x1, x2)

### Test

In [None]:
x1 = torch.tensor([[
    [1, 2, 3.5],
    [0.5, 2, 3.5],
    [0, 0, 0]
]])
j = jaccard_similarity(x1, x1)
e = exponential_similarity(x1, x1, 1)
print(j)
print(e)

tensor([[[1.0000, 0.9852, 0.0000],
         [0.9852, 1.0000, 0.0000],
         [0.0000, 0.0000, 1.0000]]])
tensor([[[1.0000, 0.6065, 0.0015],
         [0.6065, 1.0000, 0.0025],
         [0.0015, 0.0025, 1.0000]]])


## Operation Tools

In [None]:
class Logger():

    def __init__(self,
                 dataset_name:str,
                 device:str,
                 max_iterations:int,
                 vlm_load:int,
                 termination_auc:float,
                 learning_rate:float,
                 batch_size:int,
                 tot_epochs:int,
                 symbol_size:int,
                 cnn_hidden_dims:tuple[int, ...],
                 embed_dim:int,
                 drop_prob:float,
                 observation_delay:int,
                 observation_patience:int,
                 ) -> None:
        self.dataset_name = dataset_name
        self.device = device
        self.max_iterations = max_iterations
        self.vlm_load = vlm_load
        self.termination_auc = termination_auc
        self.learning_rate = learning_rate
        self.batch_size = batch_size
        self.tot_epochs = tot_epochs
        self.symbol_size = symbol_size
        self.cnn_hidden_dims = cnn_hidden_dims
        self.embed_dim = embed_dim
        self.drop_prob = drop_prob
        self.observation_delay = observation_delay
        self.observation_patience = observation_patience
        self.name = dataset_name + ' ' + datetime.now().strftime('%Y-%m-%d %H-%M-%S')
        self.json_path = os.path.join(DRIVE_LOGS_DIR, self.name + '.json')
        self.csv_path = os.path.join(DRIVE_LOGS_DIR, self.name + '.csv')
        with open(self.json_path, 'w') as json_file:
            json.dump({
                'dataset_name': dataset_name,
                'device': device,
                'max_iterations': max_iterations,
                'vlm_load': vlm_load,
                'termination_auc': termination_auc,
                'learning_rate': learning_rate,
                'batch_size': batch_size,
                'tot_epochs': tot_epochs,
                'symbol_size': symbol_size,
                'cnn_hidden_dims': cnn_hidden_dims,
                'embed_dim': embed_dim,
                'drop_prob': drop_prob,
                'observation_delay': observation_delay,
                'observation_patience': observation_patience
            }, json_file, indent=2)
        pd.DataFrame(columns=[
            'trial',
            'system_role',
            'response',
            'extracted_fol_rule',
            'exception_message',
            'stop_epoch',
            'best_epoch',
            'train_loss',
            'val_loss',
            'test_loss',
            'train_auc',
            'val_auc',
            'test_auc',
            'vlm_delay',
            'dltn_delay',
            'dev_delay',
            'eval_delay'
        ], dtype=object).to_csv(self.csv_path, index=False)

    def dump(self, key:str, param:Any, do_print:bool) -> None:
        log_df = pd.read_csv(self.csv_path, dtype=object)
        if key == 'trial':
            index = log_df.shape[0]
            prefix = ''
        else:
            index = log_df.shape[0] - 1
            prefix = '  '
        log_df.loc[index, key] = param
        if do_print:
            print(f'{prefix}{key} -> {param}')
        log_df.to_csv(self.csv_path, index=False)


def symbols_to_images(vlm_indices:torch.Tensor, dimension:int, symbol_size:int, symbols:torch.Tensor, labels:torch.Tensor) -> tuple[list[Image.Image], list[str]]:

    n_blocks = int(dimension ** 0.5)
    padding = int(0.2 * symbol_size)
    margin = 20

    padded_symbol_size = symbol_size + 2 * padding
    images = torch.zeros((vlm_indices.shape[0], dimension * padded_symbol_size, dimension * padded_symbol_size), dtype=torch.uint8)
    for n, vlm_index in enumerate(vlm_indices):
        for i in range(dimension):
            for j in range(dimension):
                images[n, i * symbol_size + (2 * i + 1) * padding:(i + 1) * symbol_size + (2 * i + 1) * padding, j * symbol_size + (2 * j + 1) * padding:(j + 1) * symbol_size + (2 * j + 1) * padding] = symbols[vlm_index, i, j, :, :]
        for l in range(n_blocks + 1):
            images[n, min(dimension * padded_symbol_size - 1,  l * padded_symbol_size * n_blocks), :] = 255
            images[n, :, min(dimension * padded_symbol_size - 1,  l * padded_symbol_size * n_blocks)] = 255

    pil_image_list = list[Image.Image]()
    base64_image_list = list[str]()
    for n, vlm_index in enumerate(vlm_indices):
        raw_pil_image = Image.fromarray(images[n].numpy(), mode='L')
        pil_image = Image.new('L', (raw_pil_image.width, raw_pil_image.height + margin), 0)
        pil_image.paste(raw_pil_image, (0, 0))
        text = f'Label: {train_labels[vlm_index].item()}'
        draw = ImageDraw.Draw(pil_image)
        font = ImageFont.load_default(12)
        bbox = draw.textbbox((0, 0), text, font=font)
        text_width = bbox[2] - bbox[0]
        text_height = bbox[3] - bbox[1]
        text_x = (pil_image.width - text_width) // 2
        text_y = raw_pil_image.height + (margin - text_height) // 2
        draw.text((text_x, text_y), text, fill=255, font=font)
        buffer = io.BytesIO()
        pil_image.save(buffer, format='PNG')
        buffer.seek(0)
        pil_image_list.append(pil_image)
        base64_image_list.append(base64.b64encode(buffer.read()).decode('utf-8'))

    return pil_image_list, base64_image_list

## Discovery Loop

In [None]:
ltn_builder = LTNBuilder()
client = Groq(api_key=GROQ_API_KEY)

In [None]:
## Logger
logger = Logger(
    dataset_name = 'fmnist_4x4_split_11',
    device = 'cuda' if torch.cuda.is_available() else 'cpu',
    max_iterations = 20,
    vlm_load = 3,
    termination_auc = 0.95,
    learning_rate = 0.001,
    batch_size = 64,
    tot_epochs = 400,
    symbol_size = 28,
    cnn_hidden_dims = (16, 32, 64),
    embed_dim = 128,
    drop_prob = 0.1,
    observation_delay = 100,
    observation_patience = 50
)

## Dataset Handling
dataset_path = os.path.join(DATASETS_DIR, f'{logger.dataset_name}.pt')
data_dict:dict[str, dict[str, torch.Tensor]] = torch.load(dataset_path)

train_dict = data_dict['train']
train_symbols = train_dict['symbols']
train_digits = train_dict['digits']
train_labels = train_dict['labels']

val_dict = data_dict['val']
val_symbols = val_dict['symbols']
val_digits = val_dict['digits']
val_labels = val_dict['labels']

test_dict = data_dict['test']
test_symbols = test_dict['symbols']
test_digits = test_dict['digits']
test_labels = test_dict['labels']

n_images, dimension, _, symbol_size, _ = train_symbols.shape

## Initialization
observation_counter = 0
best_model_epoch = -1
best_train_loss = float('inf')
best_train_auc = 0.0
best_val_loss = float('inf')
best_val_auc = 0.0
best_model_state = torch.nn.Module()
history_list = list[tuple[str, str, float]]()

# Loop
for trial in range(logger.max_iterations):

    ## Trial
    logger.dump('trial', trial, True)

    # VLM
    vlm_start_time = time.time()
    system_role = (
        'You are a helpful assistant that can extract the First-Order Logic (FOL) rule from images.'
        '\nTHE GRAMMAR OF FOL:'
        '\n- Constants: Not allowed in the rule.'
        '\n- Variables: Your options are `x1`, `x2`, ..., which represent visual objects.'
        '\n- Functions: Not allowed in the rule.'
        '\n- Predicates: Your options are `P_same_row`, `P_same_col`, `P_same_block`, `P_same_loc`, and `P_same_value`.'
        '\n- To compare variables, only use predicates.'
        '\n- The symbols used for logical AND, OR, and NOT are respectively `&`, `|`, and `!`'
        '\n- The symbols used for implication and equivalence are respectively `implies` and `iff`.'
        '\n- The symbols used for universal and existential quantifiers are respectively `forall` and `exists`.'
        '\n- Use parentheses for preserving operation precedence.'
        '\nWHAT YOU MUST CONSIDER:'
        '\n- Use your own knowledge to analyze and deeply think about the images provided as your reference.'
        '\n- All the images must follow the same rule that you extract.'
        '\n- The rule applies to the visual objects within each image.'
        '\n- The visual objects may represent numbers rather than what they really are.'
        '\n- At the end of your chain of thought, put the extracted rule in the following template:'
        '\n  EXTRACTED_RULE: "the rule you extracted"'
    )
    if len(history_list) > 0:
        system_role += '\nHISTORY OF PREVIOUS TRIALS:'
        for old_trial, (error_message, extracted_fol_rule, avg_test_auc) in enumerate(history_list):
            system_role += f'\n- Trial {old_trial+1} -> '
            if error_message != '':
                system_role += f'error: "{error_message}"'
            else:
                system_role += f'extracted rule: "{extracted_fol_rule}", average auc: {avg_test_auc}'
        if avg_test_auc < logger.termination_auc:
            system_role += '\nIMPORTANT LESSON FROM HISTORY:'
            if all(extracted_fol_rule == '' for _, extracted_fol_rule, _ in history_list):
                system_role += '- Pay attention to the the instructions!'
            else:
                system_role += '- The next FOL rule must be an improved version of the above!'
    logger.dump('system_role', system_role, False)

    user_role = [{'type': 'text', 'text': 'These are the reference images:'}]
    vlm_indices = torch.randperm(n_images)[:logger.vlm_load]
    _, base64_image_list = symbols_to_images(vlm_indices, dimension, symbol_size, train_symbols, train_labels)
    for base64_image in base64_image_list:
        user_role.append({'type': 'image_url', 'image_url': {'url': f'data:image/png;base64,{base64_image}'}})

    chat_completion = client.chat.completions.create(
        messages=[
            {'role': 'system', 'content': system_role},
            {'role': 'user', 'content': user_role}
        ],
        model='meta-llama/llama-4-maverick-17b-128e-instruct',
    )
    response = chat_completion.choices[0].message.content
    vlm_end_time = time.time()
    logger.dump('response', response, False)

    # D-LTN
    dltn_start_time = time.time()
    try:
        extracted_fol_rule:str = re.search(r'EXTRACTED_RULE\s*:\s*"([^"]*)"', response).group(1)
        logger.dump('extracted_fol_rule', extracted_fol_rule, True)
    except:
        exception_message = 'No FOL rule could not be extracted from your response!'
        history_list.append((exception_message, '', 0.0))
        logger.dump('exception_message', exception_message, True)
        continue

    try:
        ltn_formula = ltn_builder.make_formula(extracted_fol_rule)
    except:
        exception_message = 'No FOL rule could not be parsed from your response!'
        history_list.append((exception_message, '', 0.0))
        logger.dump('exception_message', exception_message, True)
        continue

    try:
        ltn_formula.ground(
            P_same_row=row_similarity,
            P_same_col=col_similarity,
            P_same_block=block_similarity,
            P_same_loc=loc_similarity,
            P_same_value=vector_similarity,
            logical_and='goguen',
            logical_or='goguen',
            implies='reichenbach',
            forall='power_mean',
            exists='power_mean'
        )
        ltn_formula(
            x1=torch.rand(1, 4, 4, device=logger.device),
            x2=torch.rand(1, 4, 4, device=logger.device)
        )
    except:
        exception_message = 'There were problems with your groundings!'
        history_list.append((exception_message, '', 0.0))
        logger.dump('exception_message', exception_message, True)
        continue
    dltn_end_time = time.time()

    ## Visual Encoder
    visual_encoder = VisualEncoder(
        device = logger.device,
        symbol_size = logger.symbol_size,
        cnn_hidden_dims = logger.cnn_hidden_dims,
        embed_dim = logger.embed_dim,
        drop_prob = logger.drop_prob
    )
    optimizer = torch.optim.Adam(visual_encoder.parameters(), lr=logger.learning_rate)

    ## Development
    dev_start_time = time.time()
    for epoch in range(logger.tot_epochs):

        if epoch < 20:
            power_mean_p = 1
        elif epoch < 120:
            power_mean_p = 2
        elif epoch < 140:
            power_mean_p = 4
        elif epoch < 170:
            power_mean_p = 6
        elif epoch < 200:
            power_mean_p = 8
        else:
            power_mean_p = 10
        ltn_formula.control(
            forall = {'p': power_mean_p}
        )

        visual_encoder.train()
        train_loss_list = list[float]()
        train_auc_list = list[float]()
        for step in range(train_symbols.shape[0] // logger.batch_size):
            x_train = visual_encoder(train_symbols[step * logger.batch_size:(step + 1) * logger.batch_size])
            predictions = ltn_formula(
                x1=x_train,
                x2=x_train
            )
            targets = train_labels[step * logger.batch_size:(step + 1) * logger.batch_size].to(logger.device).to(torch.float32)
            loss = torch.nn.functional.binary_cross_entropy(predictions, targets)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            try:
                auc = roc_auc_score(targets.detach().cpu().numpy(), predictions.detach().cpu().numpy())
            except ValueError:
                auc = 0.5
            train_loss_list.append(loss.item())
            train_auc_list.append(auc)
        avg_train_loss = sum(train_loss_list) / len(train_loss_list)
        avg_train_auc = sum(train_auc_list) / len(train_auc_list)

        visual_encoder.eval()
        with torch.no_grad():
            x_val = visual_encoder(val_symbols)
            predictions = ltn_formula(
                x1=x_val,
                x2=x_val
            )
            targets = val_labels.to(logger.device).to(torch.float32)
            loss = torch.nn.functional.binary_cross_entropy(predictions, targets)
            try:
                auc = roc_auc_score(targets.detach().cpu().numpy(), predictions.detach().cpu().numpy())
            except ValueError:
                auc = 0.5
            avg_val_loss = loss.item()
            avg_val_auc = auc

        if epoch >= logger.observation_delay:
            if avg_val_loss < best_val_loss:
                best_model_epoch = epoch
                best_train_loss = avg_train_loss
                best_train_auc = avg_train_auc
                best_val_loss = avg_val_loss
                best_val_auc = avg_val_auc
                best_model_state = visual_encoder.state_dict()
                observation_counter = 0
            else:
                observation_counter += 1
            if observation_counter >= logger.observation_patience:
                break
    dev_end_time = time.time()
    logger.dump('stop_epoch', epoch, True)
    logger.dump('best_epoch', best_model_epoch, True)
    logger.dump('train_loss', avg_train_loss, True)
    logger.dump('val_loss', avg_val_loss, True)
    logger.dump('train_auc', avg_train_auc, True)
    logger.dump('val_auc', avg_val_auc, True)

    ## Evaluation
    visual_encoder.load_state_dict(best_model_state)
    visual_encoder.eval()
    with torch.no_grad():
        x_test = visual_encoder(test_symbols)
        predictions = ltn_formula(
            x1=x_test,
            x2=x_test
        )
        targets = test_labels.to(logger.device).to(torch.float32)
        loss = torch.nn.functional.binary_cross_entropy(predictions, targets)
        try:
            auc = roc_auc_score(targets.detach().cpu().numpy(), predictions.detach().cpu().numpy())
        except ValueError:
            auc = 0.5
        avg_test_loss = loss.item()
        avg_test_auc = auc
    eval_end_time = time.time()
    logger.dump('test_loss', avg_test_loss, True)
    logger.dump('test_auc', avg_test_auc, True)

    ## Delays
    vlm_delay = vlm_end_time - vlm_start_time
    dltn_delay = dltn_end_time - dltn_start_time
    dev_delay = dev_end_time - dev_start_time
    eval_delay = eval_end_time - dev_end_time
    logger.dump('vlm_delay', vlm_delay, True)
    logger.dump('dltn_delay', dltn_delay, True)
    logger.dump('dev_delay', dev_delay, True)
    logger.dump('eval_delay', eval_delay, True)

    ## Loop Completion
    history_list.append(('', extracted_fol_rule, avg_test_auc))
    if avg_test_auc >= logger.termination_auc:
        break

# Tests

## Symbols to Image

In [None]:
data_dict:dict[str, dict[str, torch.Tensor]] = torch.load(os.path.join(DATASETS_DIR, 'mnist_4x4_split_11.pt'))

train_dict = data_dict['train']
train_symbols = train_dict['symbols']
train_digits = train_dict['digits']
train_labels = train_dict['labels']

n_images, dimension, _, symbol_size, _ = train_symbols.shape
n_blocks = int(dimension ** 0.5)

vlm_indices = torch.randperm(n_images)[:5]
pil_image_list, base64_image_list = symbols_to_images(vlm_indices, dimension, symbol_size, train_symbols, train_labels)

n_rows = torch.floor(torch.tensor(vlm_indices.shape[0] ** 0.5)).int().item()
n_cols = torch.ceil(torch.tensor(vlm_indices.shape[0] / n_rows)).int().item()
fig, axs = plt.subplots(nrows=n_rows, ncols=n_cols, figsize=(n_cols * 2.5, n_rows * 2.5))
for k, ax in enumerate(axs.flatten()):
    ax.axis('off')
    if k < vlm_indices.shape[0]:
        ax.imshow(pil_image_list[k], cmap='gray')
        ax.set_title(f'Img: {k}, Idx: {vlm_indices[k]}', fontsize=10)
fig.tight_layout()
fig.show()

for n, vlm_index in enumerate(vlm_indices):
    print(f'Img: {k}, Idx: {vlm_index}, Lbl: {train_labels[vlm_index].item()}')
    print(test_digits[n, :, :])

NameError: name 'symbols_to_images' is not defined

## Rule Test

In [None]:
# ltn_formula = LTNBuilder().make_formula('forall x1, x2 !P_same_loc(x1, x2) & (P_same_row(x1, x2) | P_same_col(x1, x2) | P_same_block(x1, x2)) implies !P_same_value(x1, x2)') #mnist
ltn_formula = LTNBuilder().make_formula('forall x1 forall x2 (P_same_row(x1, x2) | P_same_col(x1, x2) | P_same_block(x1, x2)) & !P_same_loc(x1, x2) implies !P_same_value(x1, x2)') #emnist
# ltn_formula = LTNBuilder().make_formula('forall x1, x2 (!P_same_loc(x1, x2) & P_same_value(x1, x2)) implies (!P_same_row(x1, x2) & !P_same_col(x1, x2) & !P_same_block(x1, x2))') #kmnist
print(ltn_formula.get('parsed'))

dataset_name = 'emnist_4x4_split_11'
device = 'cuda' if torch.cuda.is_available() else 'cpu'
termination_auc = 0.95
learning_rate = 0.001
batch_size = 64
tot_epochs = 400
symbol_size = 28
cnn_hidden_dims = (16, 32, 64)
embed_dim = 128
drop_prob = 0.1
observation_delay = 100
observation_patience = 50

dataset_path = os.path.join(DATASETS_DIR, f'{dataset_name}.pt')
data_dict:dict[str, dict[str, torch.Tensor]] = torch.load(dataset_path)

train_dict = data_dict['train']
train_symbols = train_dict['symbols']
train_digits = train_dict['digits']
train_labels = train_dict['labels']

val_dict = data_dict['val']
val_symbols = val_dict['symbols']
val_digits = val_dict['digits']
val_labels = val_dict['labels']

test_dict = data_dict['test']
test_symbols = test_dict['symbols']
test_digits = test_dict['digits']
test_labels = test_dict['labels']

ltn_formula.ground(
    P_same_row=row_similarity,
    P_same_col=col_similarity,
    P_same_block=block_similarity,
    P_same_loc=loc_similarity,
    P_same_value=vector_similarity,
    logical_and='goguen',
    logical_or='goguen',
    implies='reichenbach',
    forall='power_mean',
    exists='power_mean'
)
print(ltn_formula.get('info'))

visual_encoder = VisualEncoder(
    device = device,
    symbol_size = symbol_size,
    cnn_hidden_dims = cnn_hidden_dims,
    embed_dim = embed_dim,
    drop_prob = drop_prob
)
optimizer = torch.optim.Adam(visual_encoder.parameters(), lr=learning_rate)

observation_counter = 0
best_model_epoch = -1
best_train_loss = float('inf')
best_train_auc = 0.0
best_val_loss = float('inf')
best_val_auc = 0.0
best_model_state = torch.nn.Module()
history_list = list[tuple[str, str, float]]()

for epoch in range(tot_epochs):
    pbar = tqdm(total=train_symbols.shape[0] // batch_size, desc=f'Epoch {epoch}/{tot_epochs}')

    if epoch < 20:
        power_mean_p = 1
    elif epoch < 120:
        power_mean_p = 2
    elif epoch < 140:
        power_mean_p = 4
    elif epoch < 170:
        power_mean_p = 6
    elif epoch < 200:
        power_mean_p = 8
    else:
        power_mean_p = 10
    ltn_formula.control(
        forall = {'p': power_mean_p}
    )

    visual_encoder.train()
    train_loss_list = list[float]()
    train_auc_list = list[float]()
    for step in range(pbar.total):
        x_train = visual_encoder(train_symbols[step * batch_size:(step + 1) * batch_size])
        predictions = ltn_formula(
            x1=x_train,
            x2=x_train
        )
        targets = train_labels[step * batch_size:(step + 1) * batch_size].to(device).to(torch.float32)
        loss = torch.nn.functional.binary_cross_entropy(predictions, targets)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        try:
            auc = roc_auc_score(targets.detach().cpu().numpy(), predictions.detach().cpu().numpy())
        except ValueError:
            auc = 0.5
        train_loss_list.append(loss.item())
        train_auc_list.append(auc)
        avg_train_loss = sum(train_loss_list) / len(train_loss_list)
        avg_train_auc = sum(train_auc_list) / len(train_auc_list)
        pbar.set_postfix_str(f'Train (Loss: {avg_train_loss:.4f}, AUC: {avg_train_auc:.4f})')
        pbar.update()

    visual_encoder.eval()
    with torch.no_grad():
        x_val = visual_encoder(val_symbols)
        predictions = ltn_formula(
            x1=x_val,
            x2=x_val
        )
        targets = val_labels.to(device).to(torch.float32)
        loss = torch.nn.functional.binary_cross_entropy(predictions, targets)
        accuracy = ((predictions >= 0.5).to(device).to(torch.float32) == targets).to(torch.float32).mean()
        try:
            auc = roc_auc_score(targets.detach().cpu().numpy(), predictions.detach().cpu().numpy())
        except ValueError:
            auc = 0.5
        avg_val_loss = loss.item()
        avg_val_auc = auc
        pbar.set_postfix_str(f'Train (Loss: {avg_train_loss:.4f}, AUC: {avg_train_auc:.4f}), Val (Loss: {avg_val_loss:.4f}, AUC: {avg_val_auc:.4f})')
    pbar.close()

    if epoch >= observation_delay:
        if avg_val_loss < best_val_loss:
            best_model_epoch = epoch
            best_train_loss = avg_train_loss
            best_train_auc = avg_train_auc
            best_val_loss = avg_val_loss
            best_val_auc = avg_val_auc
            best_model_state = visual_encoder.state_dict()
            observation_counter = 0
        else:
            observation_counter += 1
        if observation_counter >= observation_patience:
            break

visual_encoder.load_state_dict(best_model_state)
visual_encoder.eval()
with torch.no_grad():
    x_test = visual_encoder(test_symbols)
    predictions = ltn_formula(
        x1=x_test,
        x2=x_test
    )
    targets = test_labels.to(device).to(torch.float32)
    loss = torch.nn.functional.binary_cross_entropy(predictions, targets)
    try:
        auc = roc_auc_score(targets.detach().cpu().numpy(), predictions.detach().cpu().numpy())
    except ValueError:
        auc = 0.5

print('stop_epoch', epoch)
print('best_epoch', best_model_epoch)
print('train_loss', avg_train_loss)
print('val_loss', avg_val_loss)
print('train_auc', avg_train_auc)
print('val_auc', avg_val_auc)
print('test_loss', loss.item())
print('test_auc', auc)

forall[x1](forall[x2]((P_same_row(x1, x2) | P_same_col(x1, x2) | P_same_block(x1, x2)) & !(P_same_loc(x1, x2)) implies !(P_same_value(x1, x2))))
Ungrounded Symbols:
  x1 -> variable
  x2 -> variable
Grounded Symbols:
  P_same_block -> block_similarity[n/a]
  P_same_col -> col_similarity[n/a]
  P_same_loc -> loc_similarity[n/a]
  P_same_row -> row_similarity[n/a]
  P_same_value -> vector_similarity[n/a]
  logical_not -> complementation[1 - x]*, 
  wrapper -> identity[x]*, 
  implies -> reichenbach[1 - x1 + x1 * x2]*, lukasiewicz[min(1, 1 - x1 + x2)], godel[1 if x1 <= x2 else x2], goguen[1 if x1 <=x2 else x2 / x1], kleene_dienes[max(1 - x1, x2)]
  logical_and -> goguen[x1 * x2]*, lukasiewicz[max(0, x1 + x2 - 1)], godel[min(x1, x2)]
  logical_or -> goguen[1 - (1 - x1) * (1 - x2)]*, lukasiewicz[min(1, x1 + x2)], godel[max(x1, x2)]
  forall -> power_mean[1 - mean((1 - xi) ** p) ** (1 / p), 1 hyper-arg, 0 empty]*, lukasiewicz[max(0, sum(xi) - n + 1)], godel[min(xi)], goguen[prod(xi)]
Footnot

Epoch 0/400:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 1/400:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 2/400:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 3/400:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 4/400:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 5/400:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 6/400:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 7/400:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 8/400:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 9/400:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 10/400:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 11/400:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 12/400:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 13/400:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 14/400:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 15/400:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 16/400:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 17/400:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 18/400:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 19/400:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 20/400:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 21/400:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 22/400:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 23/400:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 24/400:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 25/400:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 26/400:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 27/400:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 28/400:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 29/400:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 30/400:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 31/400:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 32/400:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 33/400:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 34/400:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 35/400:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 36/400:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 37/400:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 38/400:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 39/400:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 40/400:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 41/400:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 42/400:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 43/400:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 44/400:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 45/400:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 46/400:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 47/400:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 48/400:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 49/400:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 50/400:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 51/400:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 52/400:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 53/400:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 54/400:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 55/400:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 56/400:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 57/400:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 58/400:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 59/400:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 60/400:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 61/400:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 62/400:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 63/400:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 64/400:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 65/400:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 66/400:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 67/400:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 68/400:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 69/400:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 70/400:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 71/400:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 72/400:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 73/400:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 74/400:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 75/400:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 76/400:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 77/400:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 78/400:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 79/400:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 80/400:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 81/400:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 82/400:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 83/400:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 84/400:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 85/400:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 86/400:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 87/400:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 88/400:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 89/400:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 90/400:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 91/400:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 92/400:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 93/400:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 94/400:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 95/400:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 96/400:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 97/400:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 98/400:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 99/400:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 100/400:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 101/400:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 102/400:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 103/400:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 104/400:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 105/400:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 106/400:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 107/400:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 108/400:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 109/400:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 110/400:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 111/400:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 112/400:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 113/400:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 114/400:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 115/400:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 116/400:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 117/400:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 118/400:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 119/400:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 120/400:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 121/400:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 122/400:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 123/400:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 124/400:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 125/400:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 126/400:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 127/400:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 128/400:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 129/400:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 130/400:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 131/400:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 132/400:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 133/400:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 134/400:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 135/400:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 136/400:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 137/400:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 138/400:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 139/400:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 140/400:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 141/400:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 142/400:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 143/400:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 144/400:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 145/400:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 146/400:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 147/400:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 148/400:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 149/400:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 150/400:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 151/400:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 152/400:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 153/400:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 154/400:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 155/400:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 156/400:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 157/400:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 158/400:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 159/400:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 160/400:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 161/400:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 162/400:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 163/400:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 164/400:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 165/400:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 166/400:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 167/400:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 168/400:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 169/400:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 170/400:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 171/400:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 172/400:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 173/400:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 174/400:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 175/400:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 176/400:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 177/400:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 178/400:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 179/400:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 180/400:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 181/400:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 182/400:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 183/400:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 184/400:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 185/400:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 186/400:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 187/400:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 188/400:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 189/400:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 190/400:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 191/400:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 192/400:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 193/400:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 194/400:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 195/400:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 196/400:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 197/400:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 198/400:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 199/400:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 200/400:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 201/400:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 202/400:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 203/400:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 204/400:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 205/400:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 206/400:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 207/400:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 208/400:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 209/400:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 210/400:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 211/400:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 212/400:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 213/400:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 214/400:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 215/400:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 216/400:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 217/400:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 218/400:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 219/400:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 220/400:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 221/400:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 222/400:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 223/400:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 224/400:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 225/400:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 226/400:   0%|          | 0/3 [00:00<?, ?it/s]

RuntimeError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
