## Sparse Tsetlin Machine training on IMDB

This notebook shows how to train the Sparse Tsetlin Machine from green-tsetlin on the **IMDB sentiment dataset**.

In [4]:
import numpy as np

seed = 42
rng = np.random.default_rng(seed)

### Sklearn CountVectorizer

sklearn's `CountVectorizer` transforms the text into a binary bag-of-words representation.

For example, the input text "I love swimming in the ocean" is transformed to: [0, 1, 1, 1, 0, 0] \
This vector is based on the vocabulary of the CountVectorizer, e.g. ["dogs", "love", "ocean", "swimming", "biking", "movie"]: each position is 1 if the corresponding vocabulary entry occurs in the text and 0 otherwise. \
We obtain the vocabulary by fitting the vectorizer on the data, which collects the words/tokens that occur in it.
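
The mapping above can be reproduced by hand. The following is a minimal pure-Python sketch, not sklearn's implementation: the helper `binary_bag_of_words` is a hypothetical name, and the regular expression is assumed to mirror CountVectorizer's default tokenization (lowercasing, tokens of two or more word characters).

```python
import re

def binary_bag_of_words(text, vocabulary):
    """Return a 0/1 vector marking which vocabulary entries occur in text."""
    # Lowercase and keep tokens of two or more word characters; this is
    # assumed to mirror CountVectorizer's defaults, so "I" is dropped.
    tokens = set(re.findall(r"\b\w\w+\b", text.lower()))
    return [1 if word in tokens else 0 for word in vocabulary]

vocabulary = ["dogs", "love", "ocean", "swimming", "biking", "movie"]
vector = binary_bag_of_words("I love swimming in the ocean", vocabulary)
print(vector)  # [0, 1, 1, 1, 0, 0]
```

Words that are not in the vocabulary ("in", "the") are simply ignored, which is why the vector length always equals the vocabulary size.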

In [5]:
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.model_selection import train_test_split
import datasets

imdb = datasets.load_dataset('imdb')
x, y = imdb['train']['text'], imdb['train']['label']

vectorizer = CountVectorizer(ngram_range=(1, 3), binary=True, lowercase=True, max_features=30000)
vectorizer.fit(x)

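# Cast to uint8 while still sparse, so the dense array is never materialized in the default int64 dtype
x_bin = vectorizer.transform(x).astype(np.uint8).toarray()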
y = np.array(y).astype(np.uint32)

shuffle_index = [i for i in range(len(x))]
rng.shuffle(shuffle_index)

x_bin = x_bin[shuffle_index]
y = y[shuffle_index]


print(np.unique(y, return_counts=True))

train_x_bin, val_x_bin, train_y, val_y = train_test_split(x_bin, y, test_size=0.2, random_state=seed, shuffle=True)

(array([0, 1], dtype=uint32), array([12500, 12500]))
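
The vectorizer in the cell above is configured with `ngram_range=(1, 3)`, so the vocabulary can contain single words, word pairs, and word triples. As a rough illustration of which candidate features one review contributes, here is a small sketch; `word_ngrams` is a hypothetical helper written for this example, not a sklearn function.

```python
def word_ngrams(tokens, min_n=1, max_n=3):
    """List all contiguous word n-grams with min_n <= n <= max_n."""
    ngrams = []
    for n in range(min_n, max_n + 1):
        # Slide a window of n tokens across the token list.
        for start in range(len(tokens) - n + 1):
            ngrams.append(" ".join(tokens[start:start + n]))
    return ngrams

tokens = ["not", "very", "good", "movie"]
features = word_ngrams(tokens)
print(features)
# ['not', 'very', 'good', 'movie',
#  'not very', 'very good', 'good movie',
#  'not very good', 'very good movie']
```

N-grams such as "not very good" let the model see negation that single words would miss. In the notebook, `max_features=30000` then limits the vocabulary to the 30000 most frequent of these candidates.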


### Install the green-tsetlin package using pip

In [6]:
# pip install green-tsetlin

In [None]:
from green_tsetlin.sparse_tsetlin_machine import SparseTsetlinMachine
from green_tsetlin.trainer import Trainer

n_clauses = 1000
s = 2.0
threshold = 2000
boost_true_positive_feedback = 1
literal_budget = 6
dynamic_AL = True
AL_size = 100
clause_size = 50
lower_ta_threshold = -40

tm = SparseTsetlinMachine(n_literals=train_x_bin.shape[1],
                          n_clauses=n_clauses,
                          s=s,
                          threshold=threshold,
                          boost_true_positive=boost_true_positive_feedback,
                          literal_budget=literal_budget,
                          dynamic_AL=dynamic_AL)
