In [16]:
import numpy as np 
import pandas as pd
from scipy.sparse import vstack
from datasets import load_dataset, DatasetDict, Dataset
from sklearn.feature_extraction.text import HashingVectorizer
from sklearn.preprocessing import OneHotEncoder
from tokenwiser.pipeline import make_partial_union

In [17]:
class ClassificationDataset:
    """Train/validation text-classification data read from a single CSV file.

    The CSV is expected to contain a ``split`` column whose values include
    ``'train'`` and ``'valid'``, plus one text column and one label column.
    Labels are one-hot encoded by an encoder fitted on every distinct value
    of the label column found in the file (both splits).

    Parameters
    ----------
    path : str
        Location of the CSV file; passed to ``pd.read_csv`` and also used to
        build ``name``.
    text_col : str, default 'text'
        Column holding the raw text.
    label_col : str, default 'label'
        Column holding the class label.

    Attributes
    ----------
    train, valid : pd.DataFrame
        Rows of the respective split, re-indexed from zero (the original row
        index is kept as an ``index`` column by ``reset_index()``).
    labels : list
        Distinct label values in order of first appearance.
    name : str
        ``path`` itself, or ``path + "-" + label_col`` when a non-default
        label column is used.
    label_enc : OneHotEncoder
        Fitted encoder producing dense one-hot rows.
    """

    def __init__(self, path, text_col='text', label_col='label'):
        dataf = pd.read_csv(path)
        self.train = dataf.loc[lambda d: d['split'] == 'train'].reset_index()
        self.valid = dataf.loc[lambda d: d['split'] == 'valid'].reset_index()
        self.labels = list(dataf[label_col].unique())
        self.text_col = text_col
        self.label_col = label_col
        self.name = path + ("" if label_col == "label" else "-" + label_col)
        label_arr = np.array(self.labels).reshape(-1, 1)
        self.label_enc = self._make_label_encoder().fit(label_arr)

    @staticmethod
    def _make_label_encoder():
        """Build a dense-output ``OneHotEncoder`` on old and new scikit-learn."""
        try:
            # Current spelling; the older ``sparse`` keyword was removed in
            # scikit-learn 1.4 and passing it raises TypeError there.
            return OneHotEncoder(sparse_output=False)
        except TypeError:
            # Releases that predate ``sparse_output`` only accept ``sparse``.
            return OneHotEncoder(sparse=False)

    def _texts_and_onehot(self, frame):
        """Return ``(text Series, one-hot label array)`` for the rows of ``frame``."""
        label_arr = np.array(frame[self.label_col]).reshape(-1, 1)
        return frame[self.text_col], self.label_enc.transform(label_arr)

    def batch(self, n):
        """Sample ``n`` training rows uniformly *with replacement*.

        Returns
        -------
        (pd.Series, np.ndarray)
            The sampled texts and their one-hot labels, shape ``(n, n_labels)``.
        """
        indices = np.random.randint(len(self.train), size=n)
        return self._texts_and_onehot(self.train.iloc[indices])

    def full(self, split="train"):
        """Return every row of one split as ``(texts, one-hot labels)``.

        ``split="train"`` selects the training rows; any other value selects
        the validation rows.
        """
        subset = self.train if split == "train" else self.valid
        return self._texts_and_onehot(subset)


class Batcher:
    """Pairs a dataset with a tokeniser and serves vectorised batches.

    Parameters
    ----------
    dataset
        Object exposing ``name``, ``labels``, ``train`` and the methods
        ``batch(n=...)`` / ``full(split=...)``, each returning
        ``(texts, one_hot_labels)``.
    tokeniser
        Object with a ``transform(texts)`` method that turns an iterable of
        texts into a (sparse) feature matrix.
    """

    def __init__(self, dataset, tokeniser):
        self.dataset = dataset
        self.tokeniser = tokeniser

    def __repr__(self):
        return f"<Batcher {self.dataset.name} n={self.dataset.train.shape[0]} k={len(self.dataset.labels)}>"

    def batch(self, n):
        """Return ``(features, one_hot_labels)`` for ``n`` sampled training rows."""
        text, labs = self.dataset.batch(n=n)
        return self.tokeniser.transform(text), labs

    def batch_X_s_y(self, n):
        """Expand a batch into one (text, candidate label) pair per row.

        Every sampled text is repeated once per label.  For a batch of ``n``
        texts and ``k`` labels the result is a tuple ``(X, s, y)`` where

        * ``X`` - the text features, stacked with ``vstack``; ``n * k`` rows,
        * ``s`` - integer array ``(n * k, k)``; row is the one-hot indicator
          of the candidate label,
        * ``y`` - array ``(n * k,)``; the text's one-hot entry for that
          candidate (1 when the candidate is the true label, else 0).
        """
        X, y = self.batch(n)
        n_labels = y.shape[1]
        X_out, s_out, y_out = [], [], []
        # Iterating the feature matrix yields one row per sampled text.
        for features, targets in zip(X, y):
            for idx, is_match in enumerate(targets):
                X_out.append(features)
                s_out.append([1 if i == idx else 0 for i in range(n_labels)])
                y_out.append(is_match)
        return vstack(X_out), np.array(s_out), np.array(y_out)

    def full(self, split="train"):
        """Return ``(features, one_hot_labels)`` for a whole split of the dataset."""
        text, labs = self.dataset.full(split=split)
        return self.tokeniser.transform(text), labs

    def transform(self, X):
        """Vectorise the raw texts ``X`` with the wrapped tokeniser."""
        # The original body referenced an undefined name ``text`` here and
        # raised NameError on every call.
        return self.tokeniser.transform(X)

In [18]:
# class TextDataset(Dataset):
#     def __init__(self, name='silicone', subset='dyda_da', split='train', n_feat=5_000):
#         self.dataset = load_dataset(name, subset)
#         if isinstance(self.dataset, DatasetDict):
#             self.dataset = self.dataset[split]
#         self.labels = list(set(i['Label'] for i in self.dataset))
#         self.name = f"{name}-{subset}-{split}"
#         self.tfm = make_partial_union(
#             HashingVectorizer(n_features=n_feat),
#             HashingVectorizer(n_features=n_feat, ngram_range=(2, 2)),
#         )
#         self.label_enc = OneHotEncoder(sparse=False).fit(np.array(self.labels).reshape(-1, 1))

#     def __len__(self):
#         return len(self.dataset)

#     def __getitem__(self, idx):
#         item = self.dataset[idx]
#         return item['Utterance'], item['Label']

#     def __repr__(self):
#         return f"<TextDataset {self.name}>"
    
#     def batch_X_y(self, n):
#         """Samples a random batch of `n` datapoints."""
#         indices = np.random.randint(len(self), size=n)
#         texts, labels = zip(*[self[int(i)] for i in indices])
#         X = self.tfm.transform(texts)
#         y = self.label_enc.transform(np.array(labels).reshape(-1, 1))
#         return X, y
    
#     def batch_X_s_y(self, n):
#         X, y = self.batch_X_y(n)
#         res = []
#         for text in X:
#             for idx, lab in enumerate(y):
#                 res.append(text, idx, np.argmax(lab))
#         return res
    
#     def full(self):
#         """Returns the full set in matrix form."""
#         texts, labels = zip(*[self[int(i)] for i in range(len(self))])
#         X = self.tfm.transform(texts)
#         y = self.label_enc.transform(np.array(labels).reshape(-1, 1))
#         return X, y
    
#     def transform(self, texts):
#         return self.tfm.transform(texts)

In [19]:
# Smoke test: load one dataset and draw a random batch of two training rows.
# The result is the pair (text Series, one-hot label array).
ClassificationDataset("data/silicone-dyda_da.csv").batch(2)

(22374                    what do you mean by that ?
 74382    i'm here to see about a fixed asset loan .
 Name: text, dtype: object,
 array([[0., 0., 0., 1.],
        [0., 1., 0., 0.]]))

In [20]:
# Number of hashed features requested from each vectoriser.
n_feat = 2000

# One tokeniser shared by every task: a partial union of a default
# HashingVectorizer and a bigram-only (ngram_range=(2, 2)) HashingVectorizer.
tok = make_partial_union(
    HashingVectorizer(n_features=n_feat),
    HashingVectorizer(n_features=n_feat, ngram_range=(2, 2)),
)
tfm = tok  # second name for the very same object

# Registry keyed by ClassificationDataset.name; each entry starts out with
# just a Batcher bound to the shared tokeniser.
my_datasets = {
    d.name: {'dataset': Batcher(d, tokeniser=tok)}
    for d in (
        ClassificationDataset(csv_path, label_col=label_col)
        for csv_path, label_col in [
            ("data/silicone-dyda_da.csv", "label"),
            ("data/silicone-dyda_e.csv", "label"),
            ("data/silicone-meld_e.csv", "label"),
            ("data/tweet_eval-emoji.csv", "label"),
            ("data/tweet_eval-emotion.csv", "label"),
            ("data/tweet_eval-sentiment.csv", "label"),
            ("data/google-emotions.csv", "anger"),
        ]
    )
}

```python
fuse = (
    FUSE(tokeniser, n_tok_feat)
      .add_task(name, subset)
      .add_task(name, subset)
      .add_task(name, subset)
)
```

In [21]:
# Show the registry: one {'dataset': Batcher} entry per dataset name.
my_datasets

{'data/silicone-dyda_da.csv': {'dataset': <Batcher data/silicone-dyda_da.csv n=87170 k=4>},
 'data/silicone-dyda_e.csv': {'dataset': <Batcher data/silicone-dyda_e.csv n=87170 k=7>},
 'data/silicone-meld_e.csv': {'dataset': <Batcher data/silicone-meld_e.csv n=9989 k=7>},
 'data/tweet_eval-emoji.csv': {'dataset': <Batcher data/tweet_eval-emoji.csv n=45000 k=20>},
 'data/tweet_eval-emotion.csv': {'dataset': <Batcher data/tweet_eval-emotion.csv n=3257 k=4>},
 'data/tweet_eval-sentiment.csv': {'dataset': <Batcher data/tweet_eval-sentiment.csv n=45615 k=3>},
 'data/google-emotions.csv-anger': {'dataset': <Batcher data/google-emotions.csv-anger n=169046 k=2>}}

In [63]:
import keras
from keras.layers import Dense, Input, Embedding, Dot
from keras.models import Model
import scipy
import numpy as np

# Vectorise a single dummy string only to learn the tokeniser's output width.
X = tok.transform(["hello"])
emb_dim = 256

# Text encoder shared by every task: the same `inputs`/`emb1`/`emb2` tensors
# are wired into each per-dataset model below, so all tasks update the same
# two dense layers.
inputs = Input(shape=(X.shape[1],), sparse=True, name="text_input")
emb1 = Dense(emb_dim, activation='relu', name="dense_layer_1")(inputs)
emb2 = Dense(emb_dim, activation='relu', name="dense_layer_2")(emb1)

for dataset in my_datasets.values():
    # Width of the one-hot label indicator fed to this task's head.  This is
    # read straight from the dataset; the previous version sampled and
    # tokenised a random batch solely to look up the same number.
    label_shape = len(dataset['dataset'].dataset.labels)

    # Task-specific head: embed the candidate label, take its dot product
    # with the text embedding and squash the score to a probability.
    label_input = Input(shape=(label_shape,), name="label_input")
    label_emb   = Dense(emb_dim, name="label_emb")(label_input)
    dot_prod    = Dot(axes=(1,1), name="dot_product")([label_emb, emb2])
    output      = Dense(1, activation='sigmoid', name="task_output")(dot_prod)

    dataset['outputs'] = output
    dataset['model'] = Model(inputs=[inputs, label_input], outputs=dataset['outputs'])
    dataset['model'].compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])


In [65]:
from sklearn.utils import class_weight


# Round-robin multi-task training: every epoch visits each dataset once and
# fits that dataset's model for a single Keras epoch on a freshly sampled
# batch (2048*2 texts, expanded to one row per candidate label).
for epoch in range(100):
    for dataset, entry in my_datasets.items():
        X, s, y = entry['dataset'].batch_X_s_y(2048*2)
        print(dataset)

        # Balance the *binary target*: only one row in k is a positive.
        # The earlier version derived weights from argmax(s), which is
        # uniform by construction (every text contributes exactly one row
        # per label), so all weights were 1.0 and had no effect -- Keras
        # applies `class_weight` to the target classes 0/1.
        # It also tokenised the full validation split on every iteration
        # without using the result; that call has been dropped.
        targets = y.astype(int)
        weights = class_weight.compute_class_weight(
            'balanced',
            classes=np.array([0, 1]),  # ndarray, as required by newer scikit-learn
            y=targets,
        )
        entry['model'].fit(x=[X, s],
                           y=y,
                           batch_size=128,
                           epochs=1,
                           class_weight={0: float(weights[0]), 1: float(weights[1])})

data/silicone-dyda_da.csv
data/silicone-dyda_e.csv
data/silicone-meld_e.csv
data/tweet_eval-emoji.csv
data/tweet_eval-emotion.csv
data/tweet_eval-sentiment.csv
data/google-emotions.csv-anger
data/silicone-dyda_da.csv
data/silicone-dyda_e.csv
data/silicone-meld_e.csv
data/tweet_eval-emoji.csv
data/tweet_eval-emotion.csv
data/tweet_eval-sentiment.csv
data/google-emotions.csv-anger
data/silicone-dyda_da.csv
data/silicone-dyda_e.csv
data/silicone-meld_e.csv
data/tweet_eval-emoji.csv
data/tweet_eval-emotion.csv
data/tweet_eval-sentiment.csv
 7/96 [=>............................] - ETA: 0s - loss: 0.5334 - accuracy: 0.7288

KeyboardInterrupt: 

In [66]:
# Expose the shared text encoder on its own and check how the embeddings of a
# few sentiment words correlate with one another (rows/columns follow the
# order of the word list).
emb_model = Model(inputs=inputs, outputs=emb2)
probe_vectors = emb_model.predict(tok.transform(["bad", "evil", "good", "joy", "happy"]))
pd.DataFrame(probe_vectors).T.corr()

Unnamed: 0,0,1,2,3,4
0,1.0,0.623198,0.081766,0.138506,0.005683
1,0.623198,1.0,-0.040192,0.03602,-0.033577
2,0.081766,-0.040192,1.0,0.830557,0.932096
3,0.138506,0.03602,0.830557,1.0,0.772061
4,0.005683,-0.033577,0.932096,0.772061,1.0
