<a href="https://colab.research.google.com/github/madziejm/1e100-ibu/blob/master/1e100ibu.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Preliminary

#### Dependencies

In [None]:
import torch
# run on the GPU when CUDA is available, otherwise fall back to the CPU
dev = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'dev = {dev}')

dev = cpu


In [None]:
# Data files are downloaded (gdown) into the working directory in every environment,
# so ROOT_DIR is the working directory both on Colab and locally.
ROOT_DIR = './'
try: # mount user's Google Drive if on Colab to save training artifacts
    from google.colab import drive
    drive.mount('/drive')
    MODEL_ROOT_DIR = '/drive/MyDrive/Colab Notebooks/1e100ibu/saves/'
except ImportError: # not on Colab: keep training artifacts next to the notebook
    MODEL_ROOT_DIR = './saves/'

In [None]:
# icecream's `ic` is used for debug printing in the training/evaluation code below
!pip install --quiet icecream
from icecream import ic

In [None]:
# !pip install --quiet -Iv torch==1.10.1
# !pip install --quiet -Iv torchtext==0.11.1

## Dataset representation

In [None]:
# spaCy 3.2.x plus pretrained pipelines; pl_core_news_md is loaded by OcenPiwoReviews.
# NOTE(review): en_core_web_sm does not appear to be loaded anywhere in this notebook
# (RateBeerReviews builds a blank English pipeline) — confirm before dropping it.
!pip install 'spacy<3.3.0,>=3.2.0' --quiet
!python -m spacy download en_core_web_sm --quiet
!python -m spacy download pl_core_news_md --quiet

Collecting en-core-web-sm==3.2.0
  Downloading https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.2.0/en_core_web_sm-3.2.0-py3-none-any.whl (13.9 MB)
[K     |████████████████████████████████| 13.9 MB 5.3 MB/s 
[38;5;2m✔ Download and installation successful[0m
You can now load the package via spacy.load('en_core_web_sm')
Collecting pl-core-news-md==3.2.0
  Downloading https://github.com/explosion/spacy-models/releases/download/pl_core_news_md-3.2.0/pl_core_news_md-3.2.0-py3-none-any.whl (87.9 MB)
[K     |████████████████████████████████| 87.9 MB 21 kB/s 
Installing collected packages: pl-core-news-md
Successfully installed pl-core-news-md-3.2.0
[38;5;2m✔ Download and installation successful[0m
You can now load the package via spacy.load('pl_core_news_md')


In [None]:
# Print installed versions; we want spaCy 3 (the pipelines use spaCy 3's nlp.add_pipe("sentencizer", ...))
!pip show spacy | egrep Version
!pip show torch | egrep Version
!pip show torchtext | egrep Version

Version: 3.2.1
Version: 1.10.0+cu111
Version: 0.11.0


#### dataset representation

In [None]:
from collections import Counter
from torchtext._torchtext import (Vocab as VocabPybind) # make use of some hidden interface
from torchtext.vocab import Vocab, build_vocab_from_iterator
from tqdm.notebook import trange, tqdm
import gc # garbage collector interface
import io
import re
import spacy # nlp toolkit
import torch
import pickle


class BaseReviews(torch.utils.data.Dataset):
    """Multi-aspect review dataset: per-review sentences of word ids plus one rating per aspect.

    Subclasses fill ``_texts`` (one entry per review, each a sequence of sentences, each a
    list of word ids once preprocessing is done), ``_aspect_ratings`` (one list of ratings
    per aspect, aligned with ``_texts``) and ``vocab``.
    """
    def __init__(self, aspects, aspect_max, aspect_ratings, texts, unkn_tok, _len, anchor_words):
        self.aspects = aspects
        self.aspect_count = len(aspects)
        self.aspect_max = aspect_max           # maximum rating per aspect, aligned with `aspects`
        self._aspect_ratings = aspect_ratings
        self._texts = texts
        self.unkn_tok = unkn_tok               # token used for out-of-vocabulary words
        self._len = _len
        self.anchor_words = anchor_words       # aspect name -> word(s) used to seed that aspect
        self.vocab = None

    def dump(self, dest_path, filename):
        """Pickle the dataset state listed below to ``dest_path/filename``."""
        contents = {
            'aspects'        : self.aspects,
            'aspect_max'     : self.aspect_max,
            '_aspect_ratings': self._aspect_ratings,
            '_texts'         : self._texts,
            'unkn_tok'       : self.unkn_tok,
            '_len'           : self._len,
            'anchor_words'   : self.anchor_words,
            'vocab'          : self.vocab,
        }
        with open(f'{dest_path}/{filename}', 'wb') as f:
            pickle.dump(contents, f)

    def load(self, dest_path, filename):
        """Restore state written by `dump` from ``dest_path/filename``.

        Only load trusted files: unpickling can execute arbitrary code.
        """
        with open(f'{dest_path}/{filename}', 'rb') as f:
            contents = pickle.load(f)
        self.aspects         = contents['aspects']
        self.aspect_max      = contents['aspect_max']
        self._aspect_ratings = contents['_aspect_ratings']
        self._texts          = contents['_texts']
        self.unkn_tok        = contents['unkn_tok']
        self._len            = contents['_len']
        self.anchor_words    = contents['anchor_words']
        self.vocab           = contents['vocab']
        # aspect_count is derived, not stored: keep it in sync with the loaded aspect list
        self.aspect_count = len(self.aspects)

    def __getitem__(self, i):
        """Return ``(sentences, ratings)`` for review ``i``.

        ``sentences`` is a tuple with one LongTensor of word ids per sentence; ``ratings`` is a
        LongTensor with the review's rating for every aspect, in ``aspects`` order.
        """
        sentences = tuple(torch.LongTensor(sent) for sent in self._texts[i])
        ratings = torch.LongTensor(tuple(self._aspect_ratings[a][i] for a in range(self.aspect_count)))
        return (sentences, ratings)

    def __len__(self):
        return self._len

In [None]:
class RateBeerReviews(BaseReviews):
    """SNAP RateBeer reviews: five aspect ratings plus free text per review.

    Raw entries look like this and are separated by a blank line:

    beer/name: John Harvards Simcoe IPA
    beer/beerId: 63836
    beer/brewerId: 8481
    beer/ABV: 5.4
    beer/style: India Pale Ale &#40;IPA&#41;
    review/appearance: 4/5
    review/aroma: 6/10
    review/palate: 3/5
    review/taste: 6/10
    review/overall: 13/20
    review/time: 1157587200
    review/profileName: hopdog
    review/text: On tap at the Springfield, PA location. Poured a deep and cloudy orange (almost a copper) color with a small sized off white head. Aromas or oranges and all around citric. Tastes of oranges, light caramel and a very light grapefruit finish. I too would not believe the 80+ IBUs - I found this one to have a very light bitterness with a medium sweetness to it. Light lacing left on the glass.
    """
    # "UPDATED:" marker some texts start with, optionally followed by a date such as "APR 29, 2008"
    _UPDATED_PREFIX = re.compile(r'^UPDATED:\s*(?:[A-Za-z]{3,}\.?\s+\d{1,2},\s*\d{4})?')

    def __init__(self):
        aspects = ['appearance', 'aroma', 'palate', 'taste', 'overall']
        super().__init__(
            aspects        = aspects,
            aspect_max     = [5, 10, 5, 10, 20],
            aspect_ratings = [ [] for _ in aspects ],
            texts          = [],
            unkn_tok       = '<unk>', # unknown/out of vocabulary token
            _len           = 0,
            # one-element tuples need the trailing comma: ('aroma') is just the string 'aroma'
            anchor_words = {
                'appearance' : ('appearance', 'color'),
                'aroma'      : ('aroma',),
                'palate'     : ('palate', 'mouthfeel'),
                'taste'      : ('taste',),
                'overall'    : ('overall',),
            },
        )
        self.pipe = None

    def build(self, filepath=f'{ROOT_DIR}/SNAP-Ratebeer.txt', max_reviews=float('inf'), min_word_freq=None, max_word_count=None):
        """Parse the raw RateBeer text file, then tokenize it and build the vocabulary.

        Args:
            filepath: path to the (UTF-8) SNAP RateBeer dump.
            max_reviews: stop after this many entries.
            min_word_freq, max_word_count: exactly one must be given; see `_post_process`.
        """
        text_prefix = 'review/text: '
        # 'review/<aspect>: <rating>/<max>' line prefix for every aspect, in aspect order
        rating_prefixes = [(idx, f'review/{aspect}: ') for idx, aspect in enumerate(self.aspects)]
        # progress-bar estimate: 40938282 is presumably the line count of the full dump,
        # and an entry spans 14 lines (13 fields + the blank separator)
        expected_lines = 40938282 if max_reviews == float('inf') else max_reviews * 14
        with io.open(filepath, encoding='utf-8') as f:
            for line in tqdm(f, total=expected_lines, desc='Reading data'):
                if line == '\n': # separator
                    self._len += 1
                    if max_reviews <= self._len:
                        break
                elif line.startswith(text_prefix):
                    text = line[len(text_prefix):]
                    if text.startswith('UPDATED:'):
                        text = self._UPDATED_PREFIX.sub('', text, count=1) # drop the marker and its date
                    # tildes can be found in data for some reason; also drop surrounding whitespace incl. the newline
                    text = text.replace('~', ' ').strip()
                    if text:
                        self._texts.append(text)
                    else: # some reviews do not have associated text; unwind (remove) their ratings for each aspect
                        for aspect_ratings in self._aspect_ratings:
                            aspect_ratings.pop()
                        self._len -= 1
                else:
                    for idx, prefix in rating_prefixes:
                        if line.startswith(prefix):
                            # lhs of split by '/' is rating, rhs is max possible rating
                            self._aspect_ratings[idx].append(int(line[len(prefix):].split('/')[0]))
                            break
        # the separator count misses a last entry without trailing blank line (and
        # over-counts repeated blank lines); the stored texts are the ground truth
        self._len = len(self._texts)
        self._post_process(min_word_freq, max_word_count)

    def _fetch_nlp_pipeline(self):
        """Create the blank-English spaCy pipeline (rule-based sentence splitting) once."""
        if not getattr(self, 'pipe', None):
            nlp = spacy.util.get_lang_class('en')()
            nlp.add_pipe("sentencizer", config={"punct_chars": ['.', '?', '!']})
            # treat stray '-'/'+' as stop words, but keep sentiment-bearing words such as 'not', 'no', 'really'
            nlp.Defaults.stop_words |= { '-', '+'}
            nlp.Defaults.stop_words -= {'mostly', 'whole', 'indeed', 'quite', 'ever', 'nothing', 'perhaps', 'not', 'no', 'only', 'well', 'really', 'except'}
            self.pipe = lambda reviews: nlp.pipe(reviews)

    def _free_nlp_pipeline(self):
        """Drop the spaCy pipeline so it can be garbage-collected; it is recreated on demand."""
        self.pipe = None

    def tokenize_reviews(self, reviews_texts):
        """Split each review text into sentences of lowercased content words.

        Stop words, punctuation, whitespace and tokens of at most 2 characters are dropped.
        Returns one tuple of word lists (one list per sentence) per input text.
        """
        self._fetch_nlp_pipeline() # the pipeline is not pickled, e.g. it is unset after `load`
        return [
            tuple(
                [tok.lower_ for tok in sent
                 if not (tok.is_stop or tok.is_punct or tok.is_space) and len(tok) > 2]
                for sent in doc.sents if 0 != len(sent)
            )
            for doc in self.pipe(reviews_texts)
        ]

    def id_map_reviews(self, texts):
        """Map every word of every sentence to its vocabulary id (unknown words -> `unkn_tok` id)."""
        return [tuple(self.vocab.lookup_indices(sent) for sent in text) for text in texts]

    def _post_process(self, min_word_freq=None, max_word_count=None):
        """Tokenize the raw texts, build `vocab`, and replace words with their ids.

        Exactly one of ``min_word_freq`` (keep words seen at least that often) and
        ``max_word_count`` (keep that many most frequent words) must be given. The
        unknown token always gets id 0 and is the default for out-of-vocabulary words.
        """
        assert (min_word_freq is None) != (max_word_count is None), "provide one of min_word_freq and max_word_count"
        print("Spacy pipe (tokenization&sentence split)..")
        gc.collect() # force garbage collection
        self._texts = self.tokenize_reviews(self._texts)
        for text in self._texts:
            # removing stop words could, unfortunately, introduce empty reviews
            assert 0 != len(text), "empty review after tokenization"
        print("Building vocab (word-id mapping)..")
        gc.collect() # force garbage collection
        sent_gen = (sent for text in self._texts for sent in text)
        if min_word_freq is not None:
            # `specials` are placed first, so the unknown token gets id 0
            self.vocab = build_vocab_from_iterator(sent_gen, min_freq=min_word_freq, specials=[self.unkn_tok])
        else:
            words = Counter()
            for tokens in sent_gen:
                words.update(tokens)
            words = [word for word, _ in words.most_common(max_word_count)] # sorted by decreasing frequency
            self.vocab = Vocab(VocabPybind(words, None))
            self.vocab.insert_token(self.unkn_tok, 0)
        self.vocab.set_default_index(self.vocab[self.unkn_tok]) # set index for out-of-vocabulary words
        print("Mapping words to ids..")
        gc.collect() # force garbage collection
        self._texts = self.id_map_reviews(self._texts)
        gc.collect() # force garbage collection

By default (`USE_RATEBEER_PICKLE=false` in the cell below) the dataset is parsed from the original SNAP RateBeer text file, and the result is serialized to `ratebeer-20K-vocab.pickle`. Set `USE_RATEBEER_PICKLE=true` to download that serialized `RateBeerReviews` object instead and skip parsing the text file.

In [None]:
# `true`: load the preprocessed RateBeer pickle; `false`: parse the original text file
%env USE_RATEBEER_PICKLE=false

env: USE_RATEBEER_PICKLE=false


In [None]:
%%bash

RATEBEER_FILE='./SNAP-Ratebeer.txt' # original dataset location, also read by the next cell

if [ "$USE_RATEBEER_PICKLE" = true ]
then # download pickle
    if [ ! -f './ratebeer-20K-vocab.pickle' ]
    then
        gdown --id '1VBDjyR4jpzAgzcDUGNQFguOfLC3rtOV_' # https://drive.google.com/file/d/1VBDjyR4jpzAgzcDUGNQFguOfLC3rtOV_/view?usp=sharing  # 20K words dataset
        # gdown --id '1ebDMDlOxtFh8B5i8lajR7q3kq-0hM02j' # https://drive.google.com/file/d/1ebDMDlOxtFh8B5i8lajR7q3kq-0hM02j/view?usp=sharing # min frequency 5 words dataset
    fi
else # download original dataset
    if [ ! -f "$RATEBEER_FILE" ]
    then
        gdown --id '12tEEYQcHZtg5aWyfIiWWVIDAJNT-5d_T' # https://drive.google.com/file/d/12tEEYQcHZtg5aWyfIiWWVIDAJNT-5d_T/view?usp=sharing
        echo "Dataset head (trailing newline makes entry end): "
        head -n 16 "$RATEBEER_FILE"
        # re-encode the ISO-8859-1 dump in place so Python can read it as UTF-8
        iconv -f ISO-8859-1 -t UTF-8 "$RATEBEER_FILE" -o "${RATEBEER_FILE}.new" && mv "${RATEBEER_FILE}.new" "$RATEBEER_FILE"
    fi
fi

In [None]:
import os # needed here: `os` is otherwise only imported in a later cell

rb = RateBeerReviews()

if os.environ.get('USE_RATEBEER_PICKLE') == 'true':
    # restore the preprocessed dataset (tokenized texts, ratings, vocab) dumped by a previous run
    rb.load('./', 'ratebeer-20K-vocab.pickle')
else: # build pickle
    rb.build('./SNAP-Ratebeer.txt', max_word_count=20000)
    print('Dumping..')
    rb.dump('./', 'ratebeer-20K-vocab.pickle')

Reading data:   0%|          | 0/40938282 [00:00<?, ?it/s]

Spacy pipe (tokenization&sentence split)..
Building vocab (word-id mapping)..
Mapping words to ids..
Dumping..


In [None]:
from collections import Counter
from torchtext._torchtext import (Vocab as VocabPybind) # make use of some hidden interface
from torchtext.vocab import Vocab, build_vocab_from_iterator
from tqdm.notebook import trange, tqdm
import gc # garbage collector interface
import io
import re
import spacy # nlp toolkit
import torch
import json

class OcenPiwoReviews(torch.utils.data.Dataset):
    """Polish beer reviews (JSON dump) with four aspect ratings (max 10), tokenized into sentences of word ids."""
    def __init__(self, filepath='/content/ocen-piwo-utf8.json', reviews_max=float('inf')):
        self.aspects = ['ogólny', 'smak', 'zapach', 'wygląd',] # overall, taste, aroma, appearance
        self.aspect_count = len(self.aspects)
        self.aspect_max = [10, 10, 10, 10]
        self._aspect_ratings = [ [] for _ in self.aspects ]
        self._texts = []
        self.unkn_tok = "<unk>" # unknown/out of vocabulary token
        self._len = 0
        # Model.init_weights reads `anchor_words`; anchor each aspect on its own name
        self.anchor_words = {aspect: (aspect,) for aspect in self.aspects}
        self.vocab = None
        self._fetch_data(filepath, reviews_max)
        self._post_process(max_word_count=20000) # 20K words should be okay

    def _fetch_data(self, filepath, reviews_max):
        """Append up to ``reviews_max`` reviews from the JSON file to ``_texts``/``_aspect_ratings``.

        The file maps some key to a list of ``[text, ratings]`` pairs, with ``ratings``
        ordered like ``aspects``.
        """
        with io.open(filepath, encoding='utf-8') as f:
            json_dict = json.load(f)

        for reviews in json_dict.values():
            for sentences, ratings in reviews:
                if self._len >= reviews_max:
                    return
                self._len += 1
                for aspect in range(self.aspect_count):
                    self._aspect_ratings[aspect].append(ratings[aspect])
                self._texts.append(sentences)

    def _post_process(self, min_word_freq=None, max_word_count=None):
        """Tokenize texts with the Polish spaCy pipeline, build `vocab`, and map words to ids.

        With ``min_word_freq`` keep words seen at least that often; otherwise keep the
        ``max_word_count`` most frequent words (all words when it is None). The unknown token
        gets id 0 and is the default for out-of-vocabulary words.
        """
        nlp = spacy.load('pl_core_news_md')
        nlp.add_pipe("sentencizer", config={"punct_chars": ['.', '?', '!']})
        nlp.Defaults.stop_words |= { '-', '+', }
        print("Spacy pipe (tokenization&sentence split)..")

        gc.collect() # force garbage collection
        self._texts = [
            tuple(
                [tok.lower_ for tok in sent
                 if not (tok.is_stop or tok.is_punct or tok.is_space) and len(tok) > 2]
                for sent in doc.sents if 0 != len(sent)
            )
            for doc in nlp.pipe(self._texts)
        ]

        for text in self._texts:
            # removing stop words could, unfortunately, introduce empty reviews
            assert 0 != len(text), "empty review after tokenization"

        print("Building vocab (word-id mapping)..")
        gc.collect() # force garbage collection
        sent_gen = (sent for text in self._texts for sent in text)

        if min_word_freq is not None:
            # `specials` are placed first, so the unknown token gets id 0
            self.vocab = build_vocab_from_iterator(sent_gen, min_freq=min_word_freq, specials=[self.unkn_tok])
        else:
            words = Counter()
            for tokens in sent_gen:
                words.update(tokens)
            words = [word for word, _ in words.most_common(max_word_count)] # sorted by decreasing frequency
            self.vocab = Vocab(VocabPybind(words, None))
            self.vocab.insert_token(self.unkn_tok, 0)
        self.vocab.set_default_index(self.vocab[self.unkn_tok]) # set index for out-of-vocabulary words
        print("Mapping words to ids..")
        gc.collect() # force garbage collection
        self._texts = [tuple(self.vocab.lookup_indices(sent) for sent in text) for text in self._texts]
        gc.collect() # force garbage collection

    def __getitem__(self, i):
        """Return ``(sentences, ratings)``: a tuple of per-sentence LongTensors of word ids and a LongTensor of aspect ratings."""
        sentences = tuple(torch.LongTensor(sent) for sent in self._texts[i])
        ratings = torch.LongTensor(tuple(self._aspect_ratings[a][i] for a in range(self.aspect_count)))
        return (sentences, ratings)

    def __len__(self):
        return self._len

In [None]:
# download the OcenPiwo reviews JSON (saved as ocen-piwo-utf8.json, see output)
!gdown --id '1RM_Sk8QeOQnjnje0gwxQfJIOIK0KLLWV'

Downloading...
From: https://drive.google.com/uc?id=1RM_Sk8QeOQnjnje0gwxQfJIOIK0KLLWV
To: /content/ocen-piwo-utf8.json
100% 29.5M/29.5M [00:00<00:00, 71.3MB/s]


In [None]:
# parse the OcenPiwo JSON, tokenize it with spaCy and build a 20K-word vocabulary
op = OcenPiwoReviews()

Spacy pipe (tokenization&sentence split)..
Building vocab (word-id mapping)..
Mapping words to ids..


In [None]:
DATASET_PICKLE = '/content/ocenpiwo-20K-vocab.pickle'  # serialized OcenPiwoReviews object

# torch.save pickles the whole dataset object, vocab included
with open(DATASET_PICKLE, 'wb') as pickle_file:
    print('Dumping..')
    torch.save(op, pickle_file)

Dumping..


The RateBeer data source is controlled by the `USE_RATEBEER_PICKLE` environment variable set above: `true` loads the serialized `RateBeerReviews` object from `ratebeer-20K-vocab.pickle`, while `false` (the default) parses the original text file.

### Training (implementation of $(1)$)

In [None]:
from torch.utils.data import random_split, DataLoader
import datetime
from scipy.optimize import linear_sum_assignment
from more_itertools import grouper

# for word clouds
import functools
import numpy as np
import os
import re
from PIL import Image
from os import path
from wordcloud import WordCloud
import matplotlib.pyplot as plt
import functools
from operator import iadd

class Model():
    """Sentence-level aspect segmentation model trained on a multi-aspect review dataset.

    ``theta`` is a (word count x aspect count) matrix of per-word aspect weights, and
    ``phis[a]`` is a (word count x aspect_max[a]) matrix of per-word weights for each possible
    rating of aspect ``a``. A sentence's score for aspect ``a`` is the sum of its words' theta
    for ``a`` plus the sum of their phi_a at the review's rating for ``a``.
    """
    def __init__(self, dataset):
        self.ds = dataset
        self.init_weights()

    def init_weights(self):
        """(Re)create ``theta`` and ``phis`` as trainable tensors on ``dev``.

        theta starts in (-0.1, 0], except that each aspect's anchor words get weight 1.0
        for that aspect (page 4). All phis start at zero.
        """
        word_count = len(self.ds.vocab.get_itos())
        # scale to (-0.1, 0.0], as we enforce this weight to 1.0 for some words below
        self.theta = torch.rand((word_count, self.ds.aspect_count), device=dev) * -0.1
        # datasets without explicit anchors are anchored on the aspect names themselves
        anchor_words = getattr(self.ds, 'anchor_words', None) or {}
        for aspect_idx, aspect in enumerate(self.ds.aspects):
            words = anchor_words.get(aspect, (aspect,))
            words = [words] if isinstance(words, str) else list(words)
            # an out-of-vocabulary anchor would map to the <unk> id and wrongly pin its weight
            in_vocab = [word for word in words if word in self.ds.vocab]
            for word_id in self.ds.vocab.lookup_indices(in_vocab):
                self.theta[word_id, aspect_idx] = 1.0
        self.theta.requires_grad_()

        # separate phi for each aspect: one column per possible rating
        self.phis = [
            torch.zeros((word_count, self.ds.aspect_max[i]), device=dev, dtype=self.theta.dtype)
            for i in range(self.ds.aspect_count)
        ]
        for phi in self.phis:
            phi.requires_grad_()

    def word_clouds(self, dest_path='/drive/MyDrive/Colab Notebooks/1e100ibu/saves/', filename='words.png', show=True):
        """Draw one row per aspect: a theta word cloud followed by one phi cloud per rating.

        A phi cloud is skipped (and reported) when a zero weight is among its 200 strongest
        words. The figure is saved to ``dest_path/filename``, shown, and then closed.
        """
        words = self.ds.vocab.get_itos()
        cols = max(self.ds.aspect_max) + 1 # theta column + one column per possible rating
        fig = plt.figure(figsize=(21, 5))
        plt.subplots_adjust(wspace=0.2)

        def draw(weights, position, title):
            cloud = WordCloud(background_color="white", color_func=lambda *args, **kwargs: 'black', width=1000, height=1000)
            cloud.generate_from_frequencies(dict(zip(words, weights)))
            fig.add_subplot(self.ds.aspect_count, cols, position)
            plt.imshow(cloud)
            plt.title(title, fontsize=2)
            plt.axis("off")

        for aspect in range(self.ds.aspect_count):
            aspect_name = self.ds.aspects[aspect]
            draw(self.theta[:, aspect].tolist(), aspect * cols + 1, f'{aspect_name} Theta')

            for rating in range(self.ds.aspect_max[aspect]):
                aphi = self.phis[aspect][:, rating].tolist()
                strongest = sorted(aphi, reverse=True)[:200]
                if any(weight == 0 for weight in strongest):
                    print(f'Omitting {aspect_name} Phi {str(rating + 1)}')
                    continue
                draw(aphi, aspect * cols + rating + 2, f'{aspect_name} Phi {str(rating + 1)}')

        plt.savefig(f'{dest_path}/{filename}', dpi=600)
        plt.show(block=show)
        plt.close(fig) # figures are drawn repeatedly during training; do not accumulate them

    def rev_words_thetas(self, rev_sens_ids):
        """Return, for each sentence (a LongTensor/list of word ids), the theta rows of its words."""
        return [self.theta[sen_ids] for sen_ids in rev_sens_ids]

    def rev_words_phis(self, rev_sens_ids):
        """Return, for each sentence, a list with the phi rows of its words for every aspect."""
        return [[self.phis[aspect_idx][sen_ids, :] for aspect_idx in range(self.ds.aspect_count)] for sen_ids in rev_sens_ids]

    def _review_nll(self, rev_sents_ids, review_aspects_scores):
        """Summed NLL of one review's sentences under their most likely aspect assignment.

        Per-sentence aspect NLLs are a negative log-softmax over aspects (eq. 6). Sentences
        are assigned to aspects by linear assignment (eq. 5); sentences left over when there
        are more sentences than aspects keep their individually most likely aspect.
        Returns a 0-dim tensor that is differentiable w.r.t. theta and the phis.
        """
        rev_thetas = self.rev_words_thetas(rev_sents_ids)
        rev_phis = self.rev_words_phis(rev_sents_ids)
        sents_scores = torch.stack([
            rev_thetas[j].sum(dim=0) + torch.stack(tuple(
                rev_phis[j][a][:, review_aspects_scores[a] - 1].sum() # ratings are 1-based
                for a in range(self.ds.aspect_count)
            )) # 1 x aspect count
            for j in range(len(rev_sents_ids))
        ]) # sent count x aspect count

        denoms = torch.logsumexp(sents_scores, dim=1, keepdim=True)
        assert denoms.shape == (len(rev_sents_ids), 1)
        sents_nlls = denoms - sents_scores

        assignment = torch.argmin(sents_nlls, dim=1)
        row_ind, col_ind = self._linear_assignement(costs=sents_nlls.detach().cpu().numpy())
        assignment[row_ind] = torch.from_numpy(col_ind).to(assignment.device)
        return sents_nlls.take_along_dim(assignment[:, None], dim=1).sum()

    def dump_weights(self, dest_path='/drive/MyDrive/Colab Notebooks/1e100ibu/saves/', filename=''):
        """Save ``theta`` and ``phis`` with torch.save to ``dest_path/filename``."""
        weights = {'phis': self.phis, 'theta': self.theta}
        # an empty filename would make torch.save target the directory itself
        torch.save(weights, f'{dest_path}/{filename or "weights"}')

    def load_weights(self, src_path):
        """Load ``theta`` and ``phis`` written by `dump_weights` onto ``dev``."""
        weights = torch.load(src_path, map_location=torch.device(dev))
        self.theta = weights['theta']
        self.phis  = weights['phis']

    def _linear_assignement(self, costs):
        """Minimum-cost assignment of sentences (rows) to aspects (columns)."""
        # for nll we want to minimize
        return linear_sum_assignment(costs, maximize=False)

    def plot_nll_history(self, dir, filename):
        """Plot train batch NLL and mean test batch NLL against epoch; save to ``dir/filename``."""
        if not self.train_epoch_history:
            print('No NLL history recorded yet; skipping plot.')
            return
        all_nlls = self.train_nll_history + self.test_nll_history
        fig = plt.figure()
        fig.set_facecolor('white') # no alpha please
        plt.xlabel('epoch')
        plt.ylabel('NLL')
        plt.xlim(left=0.0, right=self.train_epoch_history[-1])
        plt.ylim(bottom=min(0.0, min(self.train_nll_history)), top=max(all_nlls))
        plt.plot(self.train_epoch_history, self.train_nll_history, '-o')
        plt.plot(self.test_epoch_history, self.test_nll_history, '-o')
        plt.legend(['train batch NLL', 'test mean batch NLL'], loc='upper right')
        plt.title('Train and test dataset NLL')
        plt.savefig(os.path.join(dir, filename))
        plt.show()
        plt.close(fig)

    def _save_artifacts(self, dest_path, suffix, best_effort=False):
        """Save weights, word clouds and the NLL plot, tagged with ``suffix``.

        With ``best_effort`` a failing artifact is reported and skipped, so an error handler
        does not replace the exception it is handling with a new one.
        """
        steps = (
            lambda: self.dump_weights(dest_path=dest_path, filename=f'weights-{suffix}'),
            lambda: self.word_clouds(dest_path=dest_path, filename=f'cloud-{suffix}.png'),
            lambda: self.plot_nll_history(dest_path, f'plot-{suffix}.png'),
        )
        for step in steps:
            try:
                step()
            except Exception as e:
                if not best_effort:
                    raise
                print(f'Could not save {suffix} artifact: {e!r}')

    @staticmethod
    def experiments(root_dir='/drive/MyDrive/Colab Notebooks/1e100ibu/saves/', dataset=None):
        """Train one Adam model and two SGD models, each into a new timestamped directory.

        ``dataset`` defaults to the notebook-global ``rb``. ``root_dir`` is created if missing.
        """
        dataset = rb if dataset is None else dataset
        for suffix, optim in (('-adam', 'adam'), ('', 'sgd'), ('', 'sgd')):
            path = os.path.join(root_dir, str(datetime.datetime.now()) + suffix)
            os.makedirs(path)
            Model(dataset).train(dest_path=path, optim=optim)

    def test_nll(self, debug=False):
        """Mean, over test batches of 100 reviews, of the summed per-review NLL (no gradients).

        Raises ValueError when the test split is empty.
        """
        nlls = []
        with torch.no_grad():
            # do not use default collate function as it requires fixed-length input and raises this exception otherwise https://github.com/pytorch/pytorch/issues/42654
            test_loader = DataLoader(self.test_ds, batch_size=100, collate_fn=lambda x: x)
            for batch in tqdm(test_loader, desc='test validation'):
                nlls.append(sum(
                    self._review_nll(rev_sents_ids, review_aspects_scores).cpu().item()
                    for rev_sents_ids, review_aspects_scores in batch
                ))
        if not nlls:
            raise ValueError('test dataset is empty')
        nll_sum = sum(nlls)
        if debug: ic(nll_sum)
        if debug: ic(torch.exp(-torch.tensor(nll_sum)))
        return nll_sum / len(nlls)

    def train(self, dest_path='/drive/MyDrive/Colab Notebooks/1e100ibu/saves/', epoch_count=1, optim='sgd'):
        """Train theta and phis on a random 80% split; evaluate on the remaining 20%.

        Every 100 batches the train batch NLL is recorded; every 5000 batches (and at each
        epoch start) the test NLL is computed, and weights/word clouds are saved mid-epoch.
        Final, interrupt and exception artifacts are written to ``dest_path``.

        Args:
            optim: 'sgd' or 'adam'.
        Raises:
            ValueError: for an unknown ``optim``.
        """
        train_size = int(0.8 * len(self.ds))
        test_size = len(self.ds) - train_size

        params = (self.theta, *self.phis)
        lr = 0.0005
        # lr = sum(self.ds.aspect_max) * 0.01 / train_size # for whole batch iteration
        weight_decay = 0.0001
        momentum = 0.1
        if optim == 'sgd':
            self._optim = torch.optim.SGD(params=params, lr=lr, weight_decay=weight_decay, momentum=momentum)
        elif optim == 'adam': # do not use Adam, weights go to NaN with it for some reason
            self._optim = torch.optim.Adam(
                params=params,
                lr=lr,
                weight_decay=weight_decay,
                betas=(momentum, 0.999) # the first is params momentum, second RMSProp momentum (for now fix to 0.999 which is default Pytorch value)
            )
        else:
            raise ValueError(f"unknown optimizer {optim!r}; expected 'sgd' or 'adam'")
        self._sched = torch.optim.lr_scheduler.StepLR(self._optim, step_size=1, gamma=0.95)

        self.train_ds, self.test_ds = random_split(self.ds, [train_size, test_size])

        # do not use default collate function as it requires fixed-length input and raises this exception otherwise https://github.com/pytorch/pytorch/issues/42654
        train_loader = DataLoader(self.train_ds, batch_size=100, shuffle=True, collate_fn=lambda x: x)
        self.train_epoch_history = []
        self.test_epoch_history = []
        self.train_nll_history = []
        self.test_nll_history = []
        # pre-bind what the exception handler reports, so an early failure is not masked by UnboundLocalError
        batch_nlls, rev_sents_ids, batch_nll = [], None, None
        try:
            for epoch in range(epoch_count):
                batch_count = len(train_loader)
                for i, batch in enumerate(tqdm(train_loader, desc=f'train epoch {epoch}/{epoch_count}')):
                    batch_nlls = []
                    for rev_sents_ids, review_aspects_scores in batch:
                        batch_nlls.append(self._review_nll(rev_sents_ids, review_aspects_scores))
                    batch_nll = torch.stack(batch_nlls).sum()
                    self._optim.zero_grad(set_to_none=True)
                    if 0 == i % 100:
                        self.train_nll_history.append(batch_nll.cpu().detach().item())
                        self.train_epoch_history.append(epoch + i / batch_count)
                    if 0 == i % 5000:
                        if i != 0:
                            self.dump_weights(dest_path=dest_path, filename=f'weights-epoch-{epoch}-{epoch_count}-{i}')
                            self.word_clouds(dest_path=dest_path, filename=f'cloud-epoch-{epoch}-{epoch_count}-{i}.png')
                            # debug probe: theta rows of a few anchor words and one likely-unknown word
                            ic(self.rev_words_thetas([self.ds.vocab.lookup_indices(['taste', 'aroma', 'palate', 'antrunk'])]))
                        self.test_nll_history.append(self.test_nll())
                        self.test_epoch_history.append(epoch + i / batch_count)
                    batch_nll.backward()
                    self._optim.step()
                self._sched.step()

            self._save_artifacts(dest_path, 'end')

        except KeyboardInterrupt:
            print('Interrupted.')
            self._save_artifacts(dest_path, 'interrupt', best_effort=True)

        except Exception:
            ic(len(batch_nlls))
            ic(rev_sents_ids)
            ic(batch_nll)
            self._save_artifacts(dest_path, 'exception', best_effort=True)
            raise


In [None]:
%%capture output
# %%script python --no-raise-error
# model = Model(rb)
# model.train()
# one Adam run and two SGD runs on `rb`, each saved into a new timestamped directory under ./saves/
Model.experiments(root_dir='./saves/')

In [None]:
# replay the stdout/stderr and rich outputs captured by `%%capture output` in the training cell
output.show()

In [None]:
# %store output > output_log
import pickle
# NOTE(review): pickles the IPython captured-output object to a file named 'output' in the working directory
with open('output', 'wb') as f:
    pickle.dump(output, f)

In [None]:
# danger
# model.ds.vocab.lookup_indices([0])

In [None]:
# fresh model on the RateBeer dataset (random theta with pinned anchor words, zero phis); not trained here
model = Model(rb) # single run
# model.train()

In [None]:
# NOTE(review): `Model` as defined in this notebook has no `show_inference` method,
# so this call raises AttributeError unless such a method is added — TODO.
model.show_inference([
    'Good-looking bootle. Aroma is pretty astonishig. Sour and sweet palate profile. I like the taste very much. Overall I can recommend it to everyone.',
    'Tastes best from bottle. Not so heap as one could think. Nice hoppy smell. I had not supposed it will be sour though. Beautiful smooth head.',
])