In [1]:
import pickle

from datasets import ClassLabel, Dataset, concatenate_datasets, load_dataset
from transformers import pipeline
from setfit import AbsaModel

from divide_and_conquer_sentiment import PolaritySentimentModel
from divide_and_conquer_sentiment.aggregation import MLP, MLPAggregator
from divide_and_conquer_sentiment.dataloaders import load_kaggle_dataset
from divide_and_conquer_sentiment.subprediction import ABSASubpredictor, ChunkSubpredictor
from divide_and_conquer_sentiment.aggregation.sawon import SawonAggregator
from divide_and_conquer_sentiment.subprediction.sentence import Chunker

In [None]:
# Enable IPython's autoreload extension. Mode 2 re-imports modified modules
# before each cell runs, so edits to the local project package are picked up
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
# Batch-size constant for training.
# NOTE(review): no cell visible in this notebook reads it — confirm whether it
# is still needed or should be passed to the MLP / aggregator.
TRAIN_BATCH_SIZE = 256

# Read & prepare datasets


In [2]:
# Twitter US Airline Sentiment corpus from Kaggle.
# The mapping presumably renames source columns (keys) to the shared
# "text"/"label" schema (values) — confirm against load_kaggle_dataset.
twitter_column_mapping = {"text": "text", "airline_sentiment": "label"}
twitter_airlines_dataset = load_kaggle_dataset(
    "crowdflower/twitter-airline-sentiment",
    twitter_column_mapping,
    # Presumably the validation / test fractions; the result is indexed with
    # "train", "val" and "test" further down.
    val_test_perc=(0.1, 0.2),
    seed=42,
)

Dataset URL: https://www.kaggle.com/datasets/crowdflower/twitter-airline-sentiment


  response_data.getheaders())


Generating train split: 0 examples [00:00, ? examples/s]

Map:   0%|          | 0/14640 [00:00<?, ? examples/s]

Map:   0%|          | 0/14640 [00:00<?, ? examples/s]

Casting the dataset:   0%|          | 0/14640 [00:00<?, ? examples/s]

In [3]:
# Amazon headphone reviews corpus from Kaggle.
# The mapping presumably renames source columns (keys) to the shared
# "text"/"label" schema (values); how RATINGS becomes a sentiment label is
# decided inside load_kaggle_dataset — confirm there.
headphones_column_mapping = {"COMMENTS": "text", "RATINGS": "label"}
amazon_headphones_dataset = load_kaggle_dataset(
    "mdwaquarazam/headphone-dataset-review-analysis",
    headphones_column_mapping,
    # Same split fractions and seed as the airline corpus.
    val_test_perc=(0.1, 0.2),
    seed=42,
)

Dataset URL: https://www.kaggle.com/datasets/mdwaquarazam/headphone-dataset-review-analysis


  response_data.getheaders())


Generating train split: 0 examples [00:00, ? examples/s]

Map:   0%|          | 0/1604 [00:00<?, ? examples/s]

Map:   0%|          | 0/1604 [00:00<?, ? examples/s]

Casting the dataset:   0%|          | 0/1604 [00:00<?, ? examples/s]

In [4]:
# Collapses the five SST-5 label ids (0-4) onto three ids (0-2). The target ids
# are later cast to the ClassLabel names negative / neutral / positive.
SST_LABEL_MAP = {
    0: 0,
    1: 0,
    2: 1,
    3: 2,
    4: 2,
}


def map_sst_label(x):
    """Rewrite the ``label`` of one example through ``SST_LABEL_MAP``.

    The example is modified in place and the same object is returned, which is
    the shape ``Dataset.map`` expects from a per-example function.

    Raises:
        KeyError: if ``x["label"]`` is not one of the keys of ``SST_LABEL_MAP``.
    """
    coarse_label = SST_LABEL_MAP[x["label"]]
    x["label"] = coarse_label
    return x


# Label feature for the collapsed three-class SST labels (ids 0, 1, 2 in this order).
sst_label_feature = ClassLabel(names=["negative", "neutral", "positive"])

# Load SST-5, drop the textual label column, collapse the labels with
# map_sst_label, then cast "label" to the ClassLabel feature above.
sst_raw = load_dataset("SetFit/sst5")
sst_dataset = sst_raw.remove_columns(["label_text"]).map(map_sst_label).cast_column("label", sst_label_feature)

Repo card metadata block was not found. Setting CardData to empty.


In [5]:
# Build the combined train / val / test sets by stacking the matching split of
# each corpus. The Kaggle-loaded corpora call their validation split "val",
# while SST uses "validation", hence the (kaggle_split, sst_split) pairs.
train_dataset, val_dataset, test_dataset = (
    concatenate_datasets(
        [
            twitter_airlines_dataset[kaggle_split],
            amazon_headphones_dataset[kaggle_split],
            sst_dataset[sst_split],
        ]
    )
    for kaggle_split, sst_split in (("train", "train"), ("val", "validation"), ("test", "test"))
)

In [None]:
# Sanity check: print every validation example whose text is the empty string.
# The "text" column is fetched once and enumerated; indexing
# val_dataset["text"] inside the loop would re-read the whole column on every
# iteration.
val_texts = val_dataset["text"]
for i, text in enumerate(val_texts):
    if text == "":
        print(val_dataset[i])


# Train MLP on ABSA model

In [None]:
# Download the spaCy model that the next cell passes as `spacy_model` to
# ABSASubpredictor.from_pretrained. Runs as a shell command.
!spacy download en_core_web_lg

In [None]:
# ABSA sub-predictor built from two pretrained SetFit checkpoints (aspect and
# polarity) plus the spaCy model downloaded above.
subpredictor = ABSASubpredictor.from_pretrained(
    "tomaarsen/setfit-absa-bge-small-en-v1.5-restaurants-aspect",
    "tomaarsen/setfit-absa-bge-small-en-v1.5-restaurants-polarity",
    spacy_model="en_core_web_lg",
)

# MLP that aggregates the ABSA sub-predictions into 3 output classes.
# NOTE(review): input_size=4 here vs 3 for the sentence-level MLP below —
# presumably the width of one sub-prediction vector; confirm.
absa_mlp_kwargs = dict(input_size=4, output_size=3, hidden_layer_sizes=(128, 64), lr=0.01)
mlp = MLP(**absa_mlp_kwargs)
aggregator = MLPAggregator(mlp)

In [None]:
# Run the ABSA sub-predictor over the raw texts of the train split, then the
# validation split.
train_subpreds, val_subpreds = (
    subpredictor.predict(split["text"]) for split in (train_dataset, val_dataset)
)

In [None]:
# Attach the ABSA sub-predictions to each split as a "subpreds" column.
# This cell rebinds train_dataset / val_dataset, so a "subpreds" column left by
# an earlier execution is dropped first. Column-wise concatenation (axis=1)
# rejects duplicate column names, so without this a re-run would fail.
train_base = train_dataset.remove_columns("subpreds") if "subpreds" in train_dataset.column_names else train_dataset
val_base = val_dataset.remove_columns("subpreds") if "subpreds" in val_dataset.column_names else val_dataset

train_dataset = concatenate_datasets([train_base, Dataset.from_dict({"subpreds": train_subpreds})], axis=1)
val_dataset = concatenate_datasets([val_base, Dataset.from_dict({"subpreds": val_subpreds})], axis=1)

In [None]:
# Cache the train and validation datasets, as left by the preceding cell, to
# pickle files in the working directory so the sub-prediction step need not be
# repeated.
for pickle_path, split in (
    ("train_dataset_subpreds.pkl", train_dataset),
    ("val_dataset_subpreds.pkl", val_dataset),
):
    with open(pickle_path, "wb") as handle:
        pickle.dump(split, handle)

In [None]:
# Restore the cached datasets and switch them to the "torch" output format.
# NOTE: pickle.load can execute arbitrary code stored in the file — only load
# pickles that were written by this notebook.
with open("train_dataset_subpreds.pkl", "rb") as handle:
    train_dataset = pickle.load(handle)
train_dataset = train_dataset.with_format("torch")

with open("val_dataset_subpreds.pkl", "rb") as handle:
    val_dataset = pickle.load(handle)
val_dataset = val_dataset.with_format("torch")

In [None]:
# Train the MLP aggregator on the ABSA datasets prepared above (train split
# first, validation split second). Presumably reads the "subpreds" and "label"
# columns — confirm in MLPAggregator.train.
aggregator.train(train_dataset, val_dataset)

# Compute SAWON predictions

In [None]:
# Download the spaCy model that the next cell passes as `spacy_model` to
# AbsaModel.from_pretrained. The same command also appears in the ABSA section;
# it is repeated so this section can be run on its own.
!spacy download en_core_web_lg

In [6]:
# Load a pretrained SetFit ABSA model and keep only its polarity sub-model.
# NOTE(review): only the polarity checkpoint id is passed, and this cell's log
# shows that checkpoint being loaded twice — presumably once as the aspect
# model, which is then discarded. Confirm whether the polarity model can be
# loaded directly instead.
polarity_model = AbsaModel.from_pretrained(
    "tomaarsen/setfit-absa-bge-small-en-v1.5-restaurants-polarity",
    spacy_model="en_core_web_lg",
).polarity_model

INFO:sentence_transformers.SentenceTransformer:Use pytorch device_name: mps
INFO:sentence_transformers.SentenceTransformer:Load pretrained SentenceTransformer: tomaarsen/setfit-absa-bge-small-en-v1.5-restaurants-polarity
Overriding labels in model configuration from None to ['no aspect', 'aspect'].
https://scikit-learn.org/stable/model_persistence.html#security-maintainability-limitations
INFO:sentence_transformers.SentenceTransformer:Use pytorch device_name: mps
INFO:sentence_transformers.SentenceTransformer:Load pretrained SentenceTransformer: tomaarsen/setfit-absa-bge-small-en-v1.5-restaurants-polarity


In [7]:
# Wrap the SetFit polarity model in the project's PolaritySentimentModel. The
# wrapper is handed to both the chunk sub-predictor and the SAWON aggregator
# in later cells.
polarity_sentiment_model = PolaritySentimentModel(polarity_model)

In [8]:
# Chunker with default settings, combined with the polarity sentiment model
# into a chunk-level sub-predictor.
# Presumably each passage is split into sentence-like chunks that are scored
# individually — confirm in ChunkSubpredictor.
chunker = Chunker()
sentence_subpredictor = ChunkSubpredictor(chunker, polarity_sentiment_model)

In [14]:
# Chunk-level sub-predictions for the training texts. Reused below as input to
# the SAWON aggregation and as the "subpreds" column for the sentence-level MLP.
sentences_train_preds = sentence_subpredictor.predict(train_dataset["text"])



In [9]:
# Chunk-level sub-predictions for the validation texts. Reused below as input
# to the SAWON aggregation and as the "subpreds" column for the sentence-level MLP.
sentences_val_preds = sentence_subpredictor.predict(val_dataset["text"])



In [19]:
# Sanity check: print every validation example whose text is the empty string.
# The "text" column is fetched once and enumerated; indexing
# val_dataset["text"] inside the loop would re-read the whole column on every
# iteration.
val_texts = val_dataset["text"]
for i, text in enumerate(val_texts):
    if text == "":
        print(val_dataset[i])

{'text': '', 'label': 2}
{'text': '', 'label': 0}
{'text': '', 'label': 2}
{'text': '', 'label': 0}
{'text': '', 'label': 2}
{'text': '', 'label': 2}


In [11]:
# SAWON aggregator backed by the polarity sentiment model.
# NOTE(review): the meaning of the second positional argument (0.9) is defined
# by SawonAggregator — consider passing it by keyword once confirmed.
sawon = SawonAggregator(polarity_sentiment_model, 0.9)

In [15]:
# Aggregate the chunk-level training predictions with SAWON; the original
# passages are supplied alongside the sub-predictions.
train_passages = train_dataset["text"]
sawon_train_preds = sawon.aggregate(sentences_train_preds, passages=train_passages)

In [13]:
# Aggregate the chunk-level validation predictions with SAWON; the original
# passages are supplied alongside the sub-predictions.
val_passages = val_dataset["text"]
sawon_val_preds = sawon.aggregate(sentences_val_preds, passages=val_passages)

In [16]:
# Cache the SAWON predictions for both splits to pickle files in the working
# directory.
for pickle_path, preds in (
    ("sawon_train_preds.pkl", sawon_train_preds),
    ("sawon_val_preds.pkl", sawon_val_preds),
):
    with open(pickle_path, "wb") as handle:
        pickle.dump(preds, handle)

# Train MLP on sentences

In [24]:
# MLP that aggregates the chunk-level sub-predictions into 3 output classes.
# NOTE(review): input_size=3 here vs 4 for the ABSA MLP — presumably the width
# of one sub-prediction vector; confirm.
sentence_mlp_kwargs = dict(input_size=3, output_size=3, hidden_layer_sizes=(128, 64), lr=0.01)
mlp_sentences = MLP(**sentence_mlp_kwargs)
mlp_aggregator_sentences = MLPAggregator(mlp_sentences)

GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


In [25]:
# Attach the chunk-level sub-predictions to each split as a "subpreds" column.
# If the ABSA section ran in the same kernel, train_dataset / val_dataset
# already carry an ABSA "subpreds" column. Column-wise concatenation (axis=1)
# rejects duplicate column names, so any such stale column is dropped first.
train_base = train_dataset.remove_columns("subpreds") if "subpreds" in train_dataset.column_names else train_dataset
val_base = val_dataset.remove_columns("subpreds") if "subpreds" in val_dataset.column_names else val_dataset

train_sentences_dataset = concatenate_datasets(
    [train_base, Dataset.from_dict({"subpreds": sentences_train_preds})], axis=1
)
val_sentences_dataset = concatenate_datasets(
    [val_base, Dataset.from_dict({"subpreds": sentences_val_preds})], axis=1
)

Flattening the indices:   0%|          | 0/19914 [00:00<?, ? examples/s]

Flattening the indices:   0%|          | 0/2726 [00:00<?, ? examples/s]

In [26]:
# Train the sentence-level MLP aggregator on the datasets that carry the
# chunk-level "subpreds" column (train split first, validation split second).
# The saved output below ends with a KeyboardInterrupt, i.e. the recorded run
# was stopped manually.
mlp_aggregator_sentences.train(train_sentences_dataset, val_sentences_dataset)


  | Name   | Type       | Params | Mode 
----------------------------------------------
0 | layers | ModuleList | 10.6 K | train
----------------------------------------------
10.6 K    Trainable params
0         Non-trainable params
10.6 K    Total params
0.043     Total estimated model params size (MB)
4         Modules in train mode
0         Modules in eval mode


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/Users/pawel.marcinkowski/.pyenv/versions/venv-spacy/lib/python3.10/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=9` in the `DataLoader` to improve performance.


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]


Detected KeyboardInterrupt, attempting graceful shutdown ...


NameError: name 'exit' is not defined