# TAYSIR Baseline for Track 2: Extraction of a WA from a Recurrent Neural Net Already Trained on a Language Modelling Task

### Welcome!

This notebook lets you play around with the Weighted Automaton (WA) extraction baseline, which uses the spectral extraction technique.
The input neural net can be an LSTM, GRU, or SRN network; once a WA has been extracted, the notebook draws a neat little picture of it.

# Requirements
## Imports and version check

In [None]:
%pip install -q mlflow torch

In [None]:
import sys

import mlflow
import torch

print("Your torch version:", torch.__version__)
print("Your mlflow version:", mlflow.__version__)
print("Your python version:", sys.version)

This notebook was tested with:
* Torch version: 1.11.0+cu102
* MLFlow version: 1.25.1
* Python version: 3.8.10 [GCC 9.4.0]

Python versions from 3.7 onwards should work, but have not been tested.

## Choosing the task

First, select one of the phases/datasets we provide:

In [None]:
TRACK = 2
DATASET = 0

## Loading the RNN of the competition

In [None]:
model_name = f"models/{TRACK}.{DATASET}.taysir.model"

model = mlflow.pytorch.load_model(model_name)
model.eval()

### Initialisation of some useful variables

In [None]:
nb_letters = model.input_size -1
cell_type = model.cell_type

print("The alphabet contains", nb_letters, "symbols.")
print("The type of the recurrent cells is", cell_type.__name__)

## Load the data

The input data is in the following format:

```
[Number of sequences] [Alphabet size]
[Length of sequence] [List of symbols]
[Length of sequence] [List of symbols]
[Length of sequence] [List of symbols]
...
[Length of sequence] [List of symbols]
```

For example, the following data:

```
5 10
6 8 6 5 1 6 7 4 9
12 8 6 9 4 6 8 2 1 0 6 5 9
7 8 9 4 3 0 4 9
4 8 0 4 9
8 8 1 5 2 6 0 5 3 9
```

contains 5 sequences over an alphabet of size 10 (so symbols are between 0 and 9). The first sequence line starts with its length field (6), followed by the symbols 8 6 5 1 6 7 4 9. Note that 8 is the start symbol and 9 is the end symbol.

In [None]:
file = f"datasets/{TRACK}.{DATASET}.taysir.valid.words"

sequences = []
with open(file) as f:
    f.readline()  # Skip the first line (number of sequences, alphabet size)
    for line in f:
        line = line.strip()
        seq = line.split(' ')
        seq = [int(i) for i in seq[1:]]  # Remove the first value (sequence length) and cast to int
        sequences.append(seq)

The variable *sequences* is thus **a list of lists**.

In [None]:
print('Number of sequences:', len(sequences))
print('First 10 sequences:')
for seq in sequences[:10]:
    print(seq)
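
The loading cell above skips the header line. If you also want the declared number of sequences and alphabet size, a minimal sketch could look like the following. The function name `parse_taysir_words` and the two-sequence sample are our own illustrations, not part of the competition tooling.

```python
import io


def parse_taysir_words(stream):
    """Parse the header and the sequences from a file-like object."""
    header = stream.readline().split()
    nb_sequences, alphabet_size = int(header[0]), int(header[1])
    parsed = []
    for line in stream:
        fields = line.split()
        if not fields:  # tolerate blank lines
            continue
        # Drop the leading length field and keep the symbols as integers.
        parsed.append([int(tok) for tok in fields[1:]])
    return nb_sequences, alphabet_size, parsed


# Hypothetical sample (not taken from the competition data):
# 2 sequences, alphabet of size 4, start symbol 2, end symbol 3.
sample = io.StringIO("2 4\n3 2 0 3\n4 2 1 0 3\n")
nb_sequences, alphabet_size, parsed = parse_taysir_words(sample)

# Basic sanity checks on what was read.
assert nb_sequences == len(parsed)
assert all(0 <= s < alphabet_size for seq in parsed for s in seq)
```

With a real file, `open(file)` can be passed in place of the `io.StringIO` object.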

The trained RNN, provided as an MLflow model, was already loaded above.

# Model extraction
## Seeding
We seed the random number generators for reproducibility:

In [None]:
import random
import numpy as np
seed = 42
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
np.random.seed(seed)
random.seed(seed)


## WA Extraction Baseline
This is the part to replace with your own algorithm.

Our algorithm is the spectral extraction described in this paper: https://arxiv.org/abs/2009.13101

We fix a number of prefixes and a number of suffixes, then sample words from the model until that many have been collected. The model is then queried to fill the Hankel matrices, from which we build a Weighted Automaton.

This first function could be useful for other approaches:

In [None]:
def generate_one_word(model, max_len, nb_letters):
    """Use the LM-RNN to generate a sequence."""

    current_symbol = nb_letters - 2  # the start symbol is always this integer
    gen_word = [current_symbol]
    len_word = 1
    current_hidden = None  # the initial hidden state is defined this way
    with torch.no_grad():
        # The end symbol is always nb_letters - 1
        while len_word < max_len and current_symbol != nb_letters - 1:
            current_one_encoded = model.one_hot_encode([current_symbol])
            # Despite its name, in the LM task forward_bin provides the
            # probability of the next symbol given a prefix
            out, current_hidden = model.forward_bin(current_one_encoded, current_hidden)

            # Sample the next letter according to the RNN next-symbol distribution
            current_symbol = torch.multinomial(out, 1).item()
            gen_word.append(current_symbol)
            len_word += 1

    # Make sure the last symbol is the end-of-sequence one
    if len_word == max_len and current_symbol != nb_letters - 1:
        gen_word.append(nb_letters - 1)
    
    return gen_word
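
To experiment with the sampling loop without loading the RNN, here is a dependency-free sketch of the same idea. The function `toy_next_distribution` is a hypothetical stand-in for the model (it is uniform over every symbol except the start symbol), and `sample_word` is our own name; neither comes from the competition code.

```python
import random


def toy_next_distribution(prefix, nb_letters):
    """Hypothetical stand-in for the RNN: unnormalised next-symbol weights."""
    start_symbol = nb_letters - 2
    # Uniform over all symbols, except that the start symbol is never emitted again.
    return [0.0 if s == start_symbol else 1.0 for s in range(nb_letters)]


def sample_word(next_distribution, max_len, nb_letters, rng):
    """Sample symbols until the end symbol is drawn or max_len is reached."""
    start_symbol, end_symbol = nb_letters - 2, nb_letters - 1
    word = [start_symbol]
    while len(word) < max_len and word[-1] != end_symbol:
        weights = next_distribution(word, nb_letters)
        word.append(rng.choices(range(nb_letters), weights=weights, k=1)[0])
    if word[-1] != end_symbol:
        word.append(end_symbol)  # force termination, as the baseline does
    return word


rng = random.Random(42)
word = sample_word(toy_next_distribution, max_len=10, nb_letters=4, rng=rng)
print(word)
```

As in `generate_one_word`, a word that reaches `max_len` without drawing the end symbol gets it appended, so the result can contain `max_len + 1` symbols.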


Our algorithm relies on the scikit-splearn toolbox (https://remieyraud.github.io/scikit-splearn/), which can be installed using:

In [None]:
!pip3 install scikit-splearn

These functions are needed for our algorithm:

In [None]:
def add_all_prefixes(prefixes_set, word):
    """Add the prefixes of a word (lengths 2 to len(word) - 1) to an existing set of prefixes."""
    for i in range(2, len(word)):
        prefixes_set.add(tuple(word[:i]))


def add_all_suffixes(suffixes_set, word):
    """Add the suffixes of a word (starting at index 2 or later) to an existing set of suffixes."""
    for i in range(2, len(word)):
        suffixes_set.add(tuple(word[i:]))

def generate_basis(model, nb_prefixes, nb_suffixes, max_len, nb_letters):
    """A function to generate a set of prefixes and suffixes to be fed to the model to build the Hankel matrix"""
    words = set()
    prefixes = set()
    suffixes = set()
    with torch.no_grad():
        while (len(prefixes) < nb_prefixes or len(suffixes) < nb_suffixes):
            gen_word = tuple(generate_one_word(model, max_len, nb_letters))
            if gen_word not in words:
                words.add(gen_word)
                if len(prefixes) < nb_prefixes:
                    add_all_prefixes(prefixes, list(gen_word))
                if len(suffixes) < nb_suffixes:
                    add_all_suffixes(suffixes, list(gen_word))
                        
    # Sorting helps the Hankel construction; each list starts with its delimiting symbol
    rows = [(nb_letters-2,)] + sorted(list(prefixes), key=lambda t: (len(t), t[0]))
    columns = [(nb_letters-1,)] + sorted(list(suffixes), key=lambda t: (len(t), t[0]))
    
    # Build the set of all the words this basis implies, to be submitted to the RNN
    letters = [[]] + [[i] for i in range(nb_letters)]
    all_combinations = set()
    for letter in letters:
        for prefix in rows:
            for suffix in columns:
                all_combinations.add(tuple(list(prefix) + letter + list(suffix)))
    return rows, columns, list(all_combinations)

def get_values(model, all_combinations, nb_letters):
    """ returns a dictionary with all words in all_combinations as keys and corresponding RNN assigned values"""
    probas = dict()
    for word in all_combinations:
        one_hot_word = model.one_hot_encode(list(word))
        value = model.predict(one_hot_word)
        probas[tuple(word)] = value

    return probas

def create_hankels(model, prefixes, suffixes, all_combinations, nb_letters):
    """
    Redefinition of hankels(): return the list of matrices needed for extracting the WA
    :param model: an RNN in PyTorch
    :param prefixes: the list of prefixes (for rows)
    :param suffixes: the list of suffixes (for columns)
    :param all_combinations: a list of all the words whose values have to be requested from the RNN
    :param nb_letters: the number of letters of the problem
    
    :return: a list of matrices lhankels. lhankels[0] is the Hankel matrix, while
             lhankels[i] is H_{i-1}: lhankels[i][prefix][suffix] = predict(prefix + [i-1] + suffix)
    """
    print("Computing Hankels...")
    words_probas = get_values(model, all_combinations, nb_letters)
    print("    Done using the model")
    
    lhankels = [np.zeros((len(prefixes), len(suffixes))) for _ in range(nb_letters+1)]
    # empty string and letters matrices:
    letters = [[]] + [[i] for i in range(nb_letters)]
    for letter in range(len(letters)):
        for l in range(len(prefixes)):
            for c in range(len(suffixes)):
                p = words_probas[prefixes[l] + tuple(letters[letter]) + suffixes[c]]
                lhankels[letter][l][c] = p
    print("    Done computing Hankels")
    return lhankels
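
To make the layout of `lhankels` concrete, here is a toy sketch that builds the same list of matrices with plain nested lists. The function `toy_value` is a hypothetical stand-in for `model.predict`, and the prefixes and suffixes are hand-picked for illustration.

```python
def toy_value(word):
    """Hypothetical stand-in for model.predict: the value is 0.5 ** len(word)."""
    return 0.5 ** len(word)


def build_hankels(value, prefixes, suffixes, nb_letters):
    """Return [H, H_0, ..., H_{nb_letters-1}] as nested lists (rows = prefixes)."""
    # The empty tuple gives the plain Hankel matrix; (x,) gives the matrix H_x.
    inserts = [()] + [(x,) for x in range(nb_letters)]
    return [
        [[value(p + mid + s) for s in suffixes] for p in prefixes]
        for mid in inserts
    ]


prefixes = [(2,), (2, 0)]  # rows, beginning with the start symbol alone
suffixes = [(3,), (1, 3)]  # columns, beginning with the end symbol alone
hankels = build_hankels(toy_value, prefixes, suffixes, nb_letters=4)

# The word (2, 0, 3) appears both in H (prefix (2, 0), suffix (3,))
# and in H_0 (prefix (2,), suffix (3,)), so the two entries agree.
assert hankels[0][1][0] == hankels[1][0][0]
```

This shared-entry property is why the baseline first collects every needed word in `all_combinations` and queries the model once per word.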

And now we can define our baseline.

In [None]:
import splearn as sp
from numpy.linalg import svd, pinv
def spectral_distillation(model, nb_states, nb_prefixes, nb_suffixes, max_len, nb_letters):
    """
        Extract a WA of given rank from an RNN
        
        :param model: a PyTorch recurrent model
        :param nb_states: the rank of the WA to be extracted
        :param nb_prefixes: number of prefixes for the basis 
        :param nb_suffixes: number of suffixes for the basis 
        :param max_len: the maximal size of the sequences generated 
        :param nb_letters: the number of different symbols 
        
        :return: the distilled weighted automaton
    """
    prefixes, suffixes, all_combinations = generate_basis(model, nb_prefixes, nb_suffixes, max_len, nb_letters)
    hankels = create_hankels(model, prefixes, suffixes, all_combinations, nb_letters)
    # Computing the SVD
    hankel = hankels[0]
    u, s, v = svd(hankel)
    
    u = u[:, :nb_states]
    v = v[:nb_states, :]
    ds = np.diag(s[:nb_states])
    
    #Computing WA elements
    pis = pinv(v)
    del v
    pip = pinv(np.dot(u, ds))
    del u, ds
    init = np.dot(hankel[0, :], pis)
    term = np.dot(pip, hankel[:, 0])
    transitions = []
    for x in range(nb_letters):
        hankel = hankels[x+1]
        transitions.append(np.dot(pip, np.dot(hankel, pis)))
    
    WA = sp.Automaton(nbL=nb_letters, nbS=nb_states, initial=init, final=term, transitions=transitions, type="classic")
    return WA
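
For reference, the linear algebra performed by `spectral_distillation` can be summarised as follows. This is our own transcription of the code above, in notation we introduce here, not a quotation from the paper. Let $H$ be `hankels[0]`, let $H_x$ be `hankels[x+1]`, and let $r$ be `nb_states`. Write $U_r$ for the first $r$ columns of `u`, $D_r$ for the diagonal matrix of the first $r$ singular values, and $V_r$ for the first $r$ rows of `v` (NumPy's `svd` returns this factor already transposed). With ${}^{+}$ denoting the Moore-Penrose pseudo-inverse:

$$
\begin{aligned}
P^{+} &= (U_r D_r)^{+} \quad (\texttt{pip}), \qquad & S^{+} &= V_r^{+} \quad (\texttt{pis}),\\
\alpha_0^{\top} &= H[0,:]\, S^{+} \quad (\texttt{init}), \qquad & \alpha_\infty &= P^{+}\, H[:,0] \quad (\texttt{term}),\\
A_x &= P^{+} H_x\, S^{+} \quad (\texttt{transitions[x]}) & &\text{for each symbol } x.
\end{aligned}
$$

Row 0 of $H$ is indexed by the prefix made of the start symbol alone, and column 0 by the suffix made of the end symbol alone, which is why they are used for the initial and final vectors.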
    

In [None]:
WA = spectral_distillation(model, 2, 2, 2, 10, nb_letters)
print("Number of states of the extracted WA:", WA.initial.shape[0])
print("Output on example:", WA.val(sequences[42]))
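
Under the standard definition, a weighted automaton assigns to a word the product of its initial vector, one transition matrix per symbol, and its final vector; we assume `WA.val` follows this definition. The sketch below evaluates a hypothetical 2-state automaton with plain Python lists, so it runs without scikit-splearn. The name `wa_value` and the numbers are ours.

```python
def wa_value(initial, transitions, final, word):
    """Value of `word`: initial^T * A_{w1} * ... * A_{wn} * final."""
    state = list(initial)
    for symbol in word:
        matrix = transitions[symbol]
        # Row vector times matrix: propagate the state through one symbol.
        state = [
            sum(state[i] * matrix[i][j] for i in range(len(state)))
            for j in range(len(matrix[0]))
        ]
    return sum(s * f for s, f in zip(state, final))


# Hypothetical 2-state automaton over the symbols {0, 1}.
initial = [1.0, 0.0]
final = [0.0, 1.0]
transitions = [
    [[0.5, 0.5], [0.0, 0.5]],   # transition matrix for symbol 0
    [[0.0, 0.0], [0.0, 0.25]],  # transition matrix for symbol 1
]

value = wa_value(initial, transitions, final, [0, 1])
print(value)  # 0.125
```

Note that nothing in the spectral construction forces these values to be non-negative or to sum to one, so the output of an extracted WA is not guaranteed to be a probability.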

# Submission

The only thing to do is to define a function that takes a sequence as a list of integers and returns the value your model assigns to this sequence. Your model is **NOT** a parameter of this function.

In [None]:
def predict(seq):
    return WA.val(seq)

## Save and submit 
This cell creates the model needed for the submission to the competition: you just have to run it. It will create an **archive** in your current directory that you can then submit on the competition website.

**You should NOT modify this part, just run it**

In [None]:
from submit_tools import save_function

save_function(predict, alphabet_size=nb_letters, prefix=f'dataset{TRACK}.{DATASET}')

# For fun, show WA graphical representation
You may need to install the graphviz library.

In [None]:
dot = WA.get_dot(threshold = 0.01, title = 'dotfile')
# To display the dot string, one can use graphviz:
from graphviz import Source
src = Source(dot)
src.render('dotfile.gv', view=True)