# Amazon Review Classification

## Dataset Info

Amazon Review Dataset! The dataset has been binarized into two categories (negative and positive) based on the value of each review's 5-star rating.

## Model Info

Build Your Own Model!

In [37]:
from argparse import Namespace
from collections import Counter
import json
import os
import re
os.environ['OMP_NUM_THREADS'] = '4' 

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm_notebook

from vocabulary import Vocabulary

%matplotlib inline

plt.style.use('fivethirtyeight')
plt.rcParams['figure.figsize'] = (14, 6)

START_TOKEN = "^"
END_TOKEN = "_"

# Dataset

### Dataset Utilities

In [7]:
def count_tokens(x_data_list):
    """Tally how often each token occurs across a tokenized dataset
    
    Args:
        x_data_list (list(list(str))): the tokenized data points; each sublist holds the
            string tokens of a single data point.
    Returns:
        collections.Counter: a dict subclass mapping each token to the number of times
            it was seen
    
    """
    # flatten every data point into one stream of tokens and let Counter do the tallying
    return Counter(token for x_data in x_data_list for token in x_data)

def add_splits(df, target_y_column, split_proportions=(0.7, 0.15, 0.15), seed=0):
    """Add 'train', 'val', and 'test' splits to the dataset
    
    Args:
        df (pd.DataFrame): the data frame to assign splits to
        target_y_column (str): the name of the label column; in order to
            preserve the class distribution between splits, the label column
            is used to group the datapoints and splits are assigned within these groups.
        split_proportions (tuple(float, float, float)): three floats which represent the
            proportion in 'train', 'val', and 'test'. Must sum to 1 (within floating
            point tolerance). 
        seed (int): the random seed for making the shuffling deterministic. If the dataset and seed
            are kept the same, the split assignment is deterministic. 
    Returns:
        pd.DataFrame: the input dataframe with a new column for split assignments; note: row order
            will have changed.
    Raises:
        ValueError: if `split_proportions` does not hold exactly three values or the
            values do not sum to 1.
            
    """
    # Validate up front so bad arguments fail before the (slow) row-by-row pass below.
    # Floats are compared with a tolerance: proportions such as (0.6, 0.3, 0.1) may not
    # sum to exactly 1.0 in binary floating point.
    if len(split_proportions) != 3 or not np.isclose(sum(split_proportions), 1.0):
        raise ValueError("`split_proportions` should be three values that sum to 1")
    # the 'test' proportion is implicit: test receives whatever train and val leave over
    train_p, val_p, _ = split_proportions
    
    # group the rows (as plain dicts) by their label
    rows_by_label = {label: [] for label in df[target_y_column].unique()}
    for _, row in df.iterrows():
        rows_by_label[row[target_y_column]].append(row.to_dict())
    
    # NOTE: this seeds numpy's *global* random state, exactly as before
    np.random.seed(seed)
    
    out_rows = []
    # to ensure consistent behavior, visit the labels in sorted order
    for label in sorted(rows_by_label):
        data_points = rows_by_label[label]
        np.random.shuffle(data_points)
        n_total = len(data_points)
        n_train = int(train_p * n_total)
        n_val = int(val_p * n_total)
        
        # first n_train shuffled rows -> train, next n_val -> val, remainder -> test
        for position, data_point in enumerate(data_points):
            if position < n_train:
                data_point['split'] = 'train'
            elif position < n_train + n_val:
                data_point['split'] = 'val'
            else:
                data_point['split'] = 'test'
        
        out_rows.extend(data_points)
    
    return pd.DataFrame(out_rows)

### Supervised Text Vectorizer

In [8]:
class SupervisedTextVectorizer:
    """A composite data structure that uses Vocabularies to map text and its labels to integers
    
    Attributes:
        token_vocab (Vocabulary): the vocabulary managing the mapping between text tokens and 
            the unique indices that represent them
        label_vocab (Vocabulary): the vocabulary managing the mapping between labels and the
            unique indices that represent them.
        max_seq_length (int): the length of the longest sequence (including start or end tokens
            that will be prepended or appended).
    """
    def __init__(self, token_vocab, label_vocab, max_seq_length):
        """Initialize the SupervisedTextVectorizer
        
        Args:
            token_vocab (Vocabulary): the vocabulary managing the mapping between text tokens and 
                the unique indices that represent them
            label_vocab (Vocabulary): the vocabulary managing the mapping between labels and the
                unique indices that represent them.
            max_seq_length (int): the length of the longest sequence (including start or end tokens
                that will be prepended or appended).
        """
        self.token_vocab = token_vocab
        self.label_vocab = label_vocab
        self.max_seq_length = max_seq_length
        
    def _wrap_with_start_end(self, x_data):
        """Prepend the start token and append the end token.
        
        Args:
            x_data (list(str)): the list of string tokens in the data point
        Returns:
            list(str): the list of string tokens with start token prepended and end token appended
        """
        return [self.token_vocab.start_token] + x_data + [self.token_vocab.end_token]
    
    def vectorize(self, x_data, y_label):
        """Convert the data point and its label into their integer form
        
        Args:
            x_data (list(str)): the list of string tokens in the data point
            y_label (str,int): the label associated with the data point
        Returns:
            numpy.ndarray, int: x_data in vector form, right-padded with zeros to 
                max_seq_length; and the label mapped to the integer that represents it
        """
        x_data = self._wrap_with_start_end(x_data)
        x_data_indices = [self.token_vocab[token] for token in x_data]
        # Positions past the end of the sequence stay 0.
        # NOTE(review): a wrapped sequence longer than max_seq_length makes the slice
        # assignment below fail with a shape mismatch -- confirm callers never exceed it.
        x_vector = np.zeros(self.max_seq_length, dtype=np.int64)
        x_vector[:len(x_data_indices)] = x_data_indices
        y_index = self.label_vocab[y_label]
        return x_vector, y_index
    
    def transform(self, x_data_list, y_label_list):
        """Transform a dataset by vectorizing each datapoint
        
        Args: 
            x_data_list (list(list(str))): a list of lists, each sublist contains string tokens
            y_label_list (list(str,int)): a list of either strings or integers. the y label can come
                as strings or integers, but they are remapped with the label_vocab to a unique integer
        Returns:
            np.ndarray(matrix), np.ndarray(vector): the vectorized x (matrix) and vectorized y (vector) 
        """
        x_matrix = []
        y_vector = []
        for x_data, y_label in zip(x_data_list, y_label_list):
            x_vector, y_index = self.vectorize(x_data, y_label)
            x_matrix.append(x_vector)
            y_vector.append(y_index)
        
        return np.stack(x_matrix), np.stack(y_vector)
    
    @classmethod
    def from_df(cls, df, target_x_column, target_y_column, token_count_cutoff=0):
        """Instantiate the SupervisedTextVectorizer from a standardized dataframe
        
        Standardized DataFrame has a special meaning:
            there is a column that has been tokenized into a list of strings
        
        Args:
            df (pd.DataFrame): the dataset with a tokenized text column and a label column
            target_x_column (str): the name of the tokenized text column
            target_y_column (str): the name of the label column
            token_count_cutoff (int): [default=0] the minimum token frequency to add to the
                token_vocab.  Any tokens that are less frequent will not be added.
        Returns:
            SupervisedTextVectorizer: the instantiated vectorizer
        """
        # get the x data (the observations)
        target_x_list = df[target_x_column].tolist()
        # compute max sequence length, add 2 for the start, end tokens
        max_seq_length = max(map(len, target_x_list)) + 2 
        
        # populate token vocab        
        token_vocab = Vocabulary(use_unks=False,
                                 use_mask=True,
                                 use_start_end=True,
                                 start_token=START_TOKEN,
                                 end_token=END_TOKEN)
        counts = count_tokens(target_x_list)
        # visit tokens from most to least frequent so the loop can stop at the first
        # token that falls below the cutoff
        for token, count in sorted(counts.items(), key=lambda x: x[1], reverse=True):
            if count < token_count_cutoff:
                break
            token_vocab.add(token)

        # populate label vocab
        label_vocab = Vocabulary(use_unks=False, use_start_end=False, use_mask=False)
        # add the sorted unique labels 
        label_vocab.add_many(sorted(df[target_y_column].unique()))
        
        return cls(token_vocab, label_vocab, max_seq_length)
    
    def save(self, filename):
        """Save the vectorizer using json to the file specified
        
        Args:
            filename (str): the output file
        """
        # presumably get_serializable_contents returns JSON-compatible data -- verify
        # against the Vocabulary implementation
        vec_dict = {"token_vocab": self.token_vocab.get_serializable_contents(),
                    "label_vocab": self.label_vocab.get_serializable_contents(),
                    'max_seq_length': self.max_seq_length}

        # json.dump emits str, so the file must be opened in text mode; binary mode
        # ("wb") raises TypeError on Python 3
        with open(filename, "w", encoding="utf-8") as fp:
            json.dump(vec_dict, fp)
        
    @classmethod
    def load(cls, filename):
        """Load the vectorizer from the json file it was saved to
        
        Args:
            filename (str): the file into which the vectorizer was saved.
        Returns:
            SupervisedTextVectorizer: the instantiated vectorizer
        """
        with open(filename, "r", encoding="utf-8") as fp:
            contents = json.load(fp)

        # swap the serialized vocabularies for live objects; the remaining key
        # ('max_seq_length') is passed through to the constructor unchanged
        contents["token_vocab"] = Vocabulary.deserialize_from_contents(contents["token_vocab"])
        contents["label_vocab"] = Vocabulary.deserialize_from_contents(contents["label_vocab"])
        return cls(**contents)

### Supervised Text Dataset

In [10]:
class SupervisedTextDataset(Dataset):
    """A torch Dataset that serves one of the 'train', 'val' or 'test' splits at a time

    Attributes:
        vectorizer (SupervisedTextVectorizer): an instantiated vectorizer
        active_split (str): the string name of the active split
        
        # internal use
        _split_df (dict): a mapping from split name to partitioned DataFrame
        _vectorized (dict): a mapping from split to an x data matrix and y vector
        _active_df (pd.DataFrame): the DataFrame corresponding to the split
        _active_x (np.ndarray): a matrix of the vectorized text data
        _active_y (np.ndarray): a vector of the vectorized labels
    """
    def __init__(self, df, vectorizer, target_x_column, target_y_column):
        """Partition the data by split, vectorize every partition, and activate 'train'
        
        Args:
            df (pd.DataFrame): the dataset with a text column, a label column and a
                'split' column
            vectorizer (SupervisedTextVectorizer): an instantiated vectorizer
            target_x_column (str): the column containing the tokenized text
            target_y_column (str): the column containing the label
        """
        self._split_df = {name: df[df.split == name] for name in ('train', 'val', 'test')}
        
        # vectorize each partition once, up front, so switching splits is cheap
        self._vectorized = {}
        for name, frame in self._split_df.items():
            tokenized_rows = frame[target_x_column].tolist()
            labels = frame[target_y_column].tolist()
            self._vectorized[name] = vectorizer.transform(x_data_list=tokenized_rows,
                                                          y_label_list=labels)

        self.vectorizer = vectorizer
        self.active_split = self._active_df = self._active_x = self._active_y = None
        
        self.set_split("train")
        
    def set_split(self, split_name):
        """Make the named split the one served by indexing and `len`
        
        Args:
            split_name (str): the name of the split to make active; should
                be one of 'train', 'val', or 'test'
        """
        x_matrix, y_vector = self._vectorized[split_name]
        self._active_df = self._split_df[split_name]
        self._active_x = x_matrix
        self._active_y = y_vector
        self.active_split = split_name
    
    def __getitem__(self, index):
        """Fetch one data point of the active split
        
        Args:
            index (int): an int between 0 and len(self._active_x)
        Returns:
            dict: the data for this data point. Has the following form:
                {"x_data": the vectorized text data point, 
                 "y_target": the index of the label for this data point, 
                 "x_lengths": the number of nonzeros in the vector,
                 "data_index": the provided index for bookkeeping}
        """
        x_row = self._active_x[index]
        # counting the nonzero entries recovers the unpadded length of the sequence
        n_nonzero = x_row.nonzero()[0].size
        return {"x_data": x_row,
                "y_target": self._active_y[index],
                "x_lengths": n_nonzero,
                "data_index": index}
    
    def __len__(self):
        """The number of data points in the active split
        
        Returns:
            int: the number of rows in self._active_x
        """
        n_rows = self._active_x.shape[0]
        return n_rows

### Dataset Loading Function

In [17]:
# Peek at the raw CSV: the file has no header row, so the three column names are supplied here.
# NOTE(review): this cell reads "amazon_train_small.csv" while the later cells read
# "amazon_reviews_small.csv" -- confirm which file is intended.
df = pd.read_csv("../data/amazon_train_small.csv", header=None, names=['label', 'title', 'review'])
print(len(df))
# trailing expression: the notebook renders the first rows as the cell output
df.head()

100000


Unnamed: 0,label,title,review
0,2,Right on the money,We are using the this book to get 100+ certifi...
1,2,Serves its Purpose!,Couldn't go without it. My 3 1/2 year still we...
2,2,Trailer Park Bwoys!!!,we get to see it on paramount in ol' LND UK an...
3,1,buyer beware,There are companies selling Bosch knock-offs o...
4,2,Great for those cold winters,If you are looking to keep your water liquifie...


In [1]:
def character_tokenizer(input_string):
    """Tokenize a string into a list of its lowercased characters
    
    Args:
        input_string (str): the character string to tokenize
    Returns:
        list: one single-character string per character of the lowercased input
    """
    lowered = input_string.lower()
    return [character for character in lowered]


def simple_word_tokenizer(text):
    """Split a sentence into lowercased word and punctuation tokens
    
    The punctuation marks `.`, `,`, `!` and `?` become tokens of their own; every
    other token boundary is a literal space character.
    
    Args:
        text (str): the sentence string to tokenize
    Returns:
        list(str): the non-empty tokens, in order of appearance
    """
    lowered = text.lower()
    # pad the punctuation with spaces so it is split off from neighbouring words
    spaced = re.sub(r"([.,!?])", r" \1 ", lowered)
    # consecutive spaces yield empty strings; filter(None, ...) discards them
    return list(filter(None, spaced.split(" ")))

def load_amazon_review_dataset(dataset_csv, tokenizer_func):
    """Read the amazon review csv and turn it into a ready-to-use dataset
    
    The rows are assigned to splits, the review text is tokenized, and the integer
    labels 1 and 2 are renamed to 'negative' and 'positive' before the vectorizer
    and dataset are built.
    
    Args:
        dataset_csv (str): the location of the dataset
        tokenizer_func (function): the tokenizing function to turn each datapoint into 
            its tokenized form
    Returns:
        SupervisedTextDataset: the dataset, carrying the vectorizer built from the data
    """
    text_column = 'tokenized'
    label_column = 'label'
    label_names = {1: 'negative', 2: 'positive'}

    raw_df = pd.read_csv(dataset_csv, header=None, names=['label', 'title', 'review'])
    df = add_splits(raw_df, target_y_column=label_column)
    df[text_column] = df['review'].apply(tokenizer_func)
    # a label other than 1 or 2 raises KeyError here
    df[label_column] = df[label_column].apply(lambda label_int: label_names[label_int])

    vectorizer = SupervisedTextVectorizer.from_df(df,
                                                  target_x_column=text_column,
                                                  target_y_column=label_column)
    return SupervisedTextDataset(df=df,
                                 vectorizer=vectorizer,
                                 target_x_column=text_column,
                                 target_y_column=label_column)

### Verify it loads

In [40]:
# Smoke test: build the full dataset with the word-level tokenizer.
dataset = load_amazon_review_dataset("../data/amazon_reviews_small.csv", 
                                     simple_word_tokenizer)
# trailing expression: the notebook displays the first data point of the active ('train') split
dataset[0]

{'x_data': array([    1,    12,    13,    20,  1054,     5,    12,    13,  3111,
         1054,    38,   126,   368,     3,   236,   527,  1924, 39957,
        72477,   302,    65,   368,  1189,    28,     8,  1234,     9,
           12,   953,     3,   192,   119,  1838,   970,    45,    67,
           80,     6,   555,   323,    16,    27, 29828,    23,    92,
        50265, 26826,     6,  5085,     9,    55,  9942,     6,  2293,
           71,   192,     5,  1000,     4,   136,  1309,    70,     9,
           72,    10,     3,   236,    63,  6163,     5,   236,    63,
          240,    11,    69,     5,    19,    51,  1019,    58,  1391,
            6, 13078,     5,     6,    33,    19,   523,     5,    19,
           27,  1350,     4, 29829,    16,   557,    64,  1284,   172,
           18,    18,    58,    13,    20,    29,    40,   167,     9,
          155,  2067,    12,    69,     5,  1607,     3,     3,    58,
           13,    46,   577,     5,    33,    19,   122,    10,    

# Model

Fill this part in :)

### Model Utilities

### Prototyping

# Training

### Args

In [45]:
# All run configuration lives in one Namespace so later cells can read it as `args.<name>`.
args = Namespace(
    # dataset
    dataset_csv="../data/amazon_reviews_small.csv",
    # model hyper parameters
    # ADD OTHER MODEL HYPER PARAMETERS HERE
    # -1 is a placeholder: both values are overwritten in the Instantiation cell with
    # the sizes of the vectorizer's token and label vocabularies
    num_embeddings=-1,
    num_classes=-1,
    # training options
    batch_size = 128,
    cuda=False,
    learning_rate=0.001,
    num_epochs=100,
    # the training loop stops once `train_state.patience` exceeds this value
    patience_threshold=3,
)


# Check CUDA
# fall back to the CPU when no CUDA device is available, whatever was requested above
if not torch.cuda.is_available():
    args.cuda = False

print("Using CUDA: {}".format(args.cuda))

# the device that the model and every batch are moved to
args.device = torch.device("cuda" if args.cuda else "cpu")
# trailing expression: the notebook displays the chosen device
args.device

Using CUDA: False


device(type='cpu')

### Training Utilities

In [43]:
def compute_accuracy(y_pred, y_target):
    """Compute the percentage of rows whose highest-scoring class matches the target
    
    Args:
        y_pred (torch.FloatTensor): [shape=(batch_size, num_classes)]
            The matrix of predictions
        y_target (torch.LongTensor): [shape=(batch_size,)]
            The vector of label indices
    Returns:
        float: the accuracy as a percentage between 0 and 100
    """
    predicted_labels = y_pred.argmax(dim=1)
    batch_size = len(predicted_labels)
    hits = predicted_labels.eq(y_target).sum().item()
    return hits / batch_size * 100


def generate_batches(dataset, batch_size, shuffle=True,
                     drop_last=True, device="cpu", dataloader_kwargs=None): 
    """Iterate over a dataset in batches whose tensors live on the target device
    
    Args:
        dataset (torch.utils.data.Dataset): the instantiated dataset
        batch_size (int): the size of the batches
        shuffle (bool): [default=True] batches are formed from shuffled indices
        drop_last (bool): [default=True] don't return the final batch if it's smaller
            than the specified batch size
        device (str): [default="cpu"] the device to move the tensors to
        dataloader_kwargs (dict or None): [default=None] Any additional arguments to the
            DataLoader can be specified
    Yields:
        dict: a dictionary mapping from tensor name to tensor object where the first
            dimension of tensor object is the batch dimension
    Note: 
        This is a thin wrapper around the DataLoader; its one addition is moving
        every tensor of a batch to `device` before the batch is yielded.
    """
    extra_kwargs = dataloader_kwargs if dataloader_kwargs else {}
    loader = DataLoader(dataset=dataset,
                        batch_size=batch_size,
                        shuffle=shuffle,
                        drop_last=drop_last,
                        **extra_kwargs)

    for batch in loader:
        yield {name: tensor.to(device) for name, tensor in batch.items()}


class TrainState:
    """A data structure for managing training state operations.
    
    The TrainState will monitor validation loss and everytime a new best loss
        (lower is better) is observed, a couple things happen:
        
        1. The model is checkpointed
        2. Patience is reset
    
    Attributes:
        model (torch.nn.Module): the model being trained and will be
            checkpointed during training.
        dataset (SupervisedTextDataset, TextSequenceDataset): the dataset 
            which is being iterated during training; must have the `active_split`
            attribute. 
        log_dir (str): the directory to output the checkpointed model 
        patience (int): the number of epochs since a new best loss was observed
        
        # Internal Use
        _full_model_path (str): `log_dir/model_state_file`
        _metrics_by_split (dict): split name -> metric name -> a dict with the keys
            'running' (running average), 'count' and 'history'
        _split (str): the active split
        _best_loss (float): the best observed loss
    """
    def __init__(self, model, dataset, log_dir, model_state_file="model.pth"):
        """Initialize the TrainState
        
        Args:
            model (torch.nn.Module): the model to be checkpointed during training
            dataset (SupervisedTextDataset, TextSequenceDataset): the dataset 
                which is being iterated during training; must have the `active_split`
                attribute. 
            log_dir (str): the directory to output the checkpointed model; it is
                created if it does not exist yet
            model_state_file (str): the name of the checkpoint model
        """
        self.model = model
        self.dataset = dataset
        self._full_model_path = os.path.join(log_dir, model_state_file)
        # exist_ok avoids the race between checking for the directory and creating it
        os.makedirs(log_dir, exist_ok=True)
        self.log_dir = log_dir
        
        self._metrics_by_split = {
            'train': {}, 
            'val': {}, 
            'test': {}
        }
        
        self._split = 'train'
        # sentinel larger than any loss expected in practice, so the first observed
        # validation loss counts as the new best
        self._best_loss = 10**10
        self.patience = 0
        
    def _init_metric(self, split, metric_name):
        """Initialize a metric to the specified split
        
        A dictionary is created in `self._metrics_by_split` with
            the keys 'running', 'count', and 'history'. 
        
        Args:
            split (str): the target split to record the metric
            metric_name (str): the name of the metric
        """
        self._metrics_by_split[split][metric_name] = {
            'running': 0.,
            'count': 0,
            'history': []
        }
        
    def _update_metric(self, metric_name, metric_value):
        """Update a metric of the active split with an observed value
        
        Specifically, the running average is updated. The metric is created on
        first use.
        
        Args:
            metric_name (str): the name of the metric
            metric_value (float): the observed value of the metric
        """
        if metric_name not in self._metrics_by_split[self._split]:
            self._init_metric(self._split, metric_name)
        metric = self._metrics_by_split[self._split][metric_name]
        metric['count'] += 1
        # incremental mean: move the average toward the new value by 1/count of the gap
        metric['running'] += (metric_value - metric['running']) / metric['count']
        
    def set_split(self, split):
        """Set the dataset split
        
        Note that `log_metrics` overwrites this value with `dataset.active_split`
        on every call.
        
        Args:
            split (str): the target split to set
        """
        self._split = split
        
    def get_history(self, split, metric_name):
        """Get the history of values for any metric in any split
        
        Args:
            split (str): the target split
            metric_name (str): the target metric
            
        Returns:
            list(float): the running average of each epoch for `metric_name` in `split` 
        """
        return self._metrics_by_split[split][metric_name]['history']
    
    def get_value_of(self, split, metric_name):
        """Retrieve the running average of any metric in any split
        
        Args:
            split (str): the target split
            metric_name (str): the target metric
            
        Returns:
            float: the running average for `metric_name` in `split`
        """
        return self._metrics_by_split[split][metric_name]['running']
        
    def log_metrics(self, **metrics):
        """Log some values for some metrics
        
        The values are recorded under the dataset's currently active split.
        
        Args:
            metrics (kwargs): pass keyword args with the form `metric_name=metric_value`
                to log the metric values into the attribute `_metrics_by_split`.
        """
        self._split = self.dataset.active_split
        for metric_name, metric_value in metrics.items():
            self._update_metric(metric_name, metric_value)
            
    def log_epoch_end(self):
        """Log the end of the epoch. 
        
        Some key functions happen at the end of the epoch:
            - for each metric in each split the running average is appended to 
              the history, then the running average and count are reset
            - the model is checkpointed if a new best value is observed
            - patience is incremented if a new best value is not observed
        """
        for split_dict in self._metrics_by_split.values():
            for metric_dict in split_dict.values():
                # a metric that received no values this epoch contributes 0.0
                metric_dict['history'].append(metric_dict['running'])
                metric_dict['running'] = 0.0
                metric_dict['count'] = 0
                
        # checkpointing and patience are driven by the validation loss only, and only
        # once a 'loss' metric has been logged for the 'val' split
        if 'loss' in self._metrics_by_split['val']:
            val_loss = self._metrics_by_split['val']['loss']['history'][-1]
            if val_loss < self._best_loss:
                self._best_loss = val_loss
                self.save_model()
                self.patience = 0
            else:
                self.patience += 1
    
    def save_model(self):
        """ Save `model` to `log_dir/model_state_file` """
        torch.save(self.model.state_dict(), self._full_model_path)
    
    def reload_best(self):
        """ reload `log_dir/model_state_file` to `model`; does nothing if no checkpoint exists """
        if os.path.exists(self._full_model_path):
            self.model.load_state_dict(torch.load(self._full_model_path))

### Instantiation

In [46]:
# Build the dataset (and with it the vectorizer) from the csv configured in `args`.
dataset = load_amazon_review_dataset(args.dataset_csv, 
                                     tokenizer_func=simple_word_tokenizer)

# replace the -1 placeholders in `args` with the real vocabulary sizes
args.num_embeddings = len(dataset.vectorizer.token_vocab)
args.num_classes = len(dataset.vectorizer.label_vocab)

# Placeholder: the training routine below uses a variable named `model`, which
# has to be created here.
# model = ....

### Training Routine

In [None]:
# Train with early stopping on the validation loss, then evaluate the best
# checkpoint on the test split.
model = model.to(args.device)

# Keyword arguments keep each value bound to the right parameter: TrainState takes
# (model, dataset, log_dir, model_state_file) and reads `dataset.active_split`
# whenever metrics are logged.
train_state = TrainState(model=model,
                         dataset=dataset,
                         log_dir='./logs/amazon_reviews/v1',
                         model_state_file='model.pth')

optimizer = optim.Adam(model.parameters(), lr=args.learning_rate)

loss_func = nn.CrossEntropyLoss()


# progress bars

epoch_bar = tqdm_notebook(desc='epochs', total=args.num_epochs, position=1)

# generate_batches drops the last partial batch, hence the floor division
dataset.set_split("train")
train_bar = tqdm_notebook(desc='training', total=len(dataset)//args.batch_size)

dataset.set_split("val")
val_bar = tqdm_notebook(desc='validation', total=len(dataset)//args.batch_size)
        

try:
    for _ in range(args.num_epochs):
        model.train()
        dataset.set_split("train")
        # TODO: deprecate in favor of single source of truth
        train_state.set_split("train")
        
        for batch in generate_batches(dataset, batch_size=args.batch_size, device=args.device):
            # Step 1: clear the gradients 
            optimizer.zero_grad()
            
            # Step 2: compute the outputs
            y_prediction = model(batch['x_data'])

            # Step 3: compute the loss
            loss = loss_func(y_prediction, batch['y_target'])
            
            # Step 4: propagate the gradients
            loss.backward() 
            
            # Step 5: update the model weights
            optimizer.step()
            
            # Auxiliary: logging
            train_state.log_metrics(loss=loss.item(), 
                                    accuracy=compute_accuracy(y_prediction, batch['y_target']))
            
            train_bar.set_postfix(loss=train_state.get_value_of(split="train", metric_name="loss"),
                                  acc=train_state.get_value_of(split="train", metric_name="accuracy"))
            train_bar.update()
            
        # loop over the validation split
        
        model.eval()
        dataset.set_split("val")
        train_state.set_split("val")
        
        # no gradients are needed for evaluation; skipping the autograd bookkeeping
        # saves memory and time
        with torch.no_grad():
            for batch in generate_batches(dataset, batch_size=args.batch_size, device=args.device):
                # Step 1: compute the outputs
                y_prediction = model(batch['x_data'])

                # Step 2: compute the loss
                loss = loss_func(y_prediction, batch['y_target'])
                
                # Auxiliary: logging
                train_state.log_metrics(loss=loss.item(), 
                                        accuracy=compute_accuracy(y_prediction, batch['y_target']))
                
                val_bar.set_postfix(loss=train_state.get_value_of(split="val", metric_name="loss"),
                                    acc=train_state.get_value_of(split="val", metric_name="accuracy"))
                val_bar.update()

        
        # the running averages must be read before log_epoch_end() resets them
        epoch_bar.set_postfix(train_loss=train_state.get_value_of(split="train", 
                                                                  metric_name="loss"), 
                              train_accuracy=train_state.get_value_of(split="train", 
                                                                      metric_name="accuracy"),
                              val_loss=train_state.get_value_of(split="val", 
                                                                metric_name="loss"), 
                              val_accuracy=train_state.get_value_of(split="val", 
                                                                    metric_name="accuracy"),
                              patience=train_state.patience)
        epoch_bar.update()
        train_state.log_epoch_end()
        # rewind the per-epoch progress bars
        train_bar.n = 0
        val_bar.n = 0
        
        # early stopping: too many epochs without a new best validation loss
        if train_state.patience > args.patience_threshold:
            break

    # evaluate the best checkpoint (if one was written) on the held-out test split
    train_state.reload_best()
    model.eval()
    dataset.set_split("test")
    test_bar = tqdm_notebook(desc='test', total=len(dataset)//args.batch_size)

    with torch.no_grad():
        for batch in generate_batches(dataset, batch_size=args.batch_size, device=args.device):
            # Step 1: compute the outputs
            y_prediction = model(batch['x_data'])

            # Step 2: compute the loss
            loss = loss_func(y_prediction, batch['y_target'])

            # Auxiliary: logging (recorded under 'test', the dataset's active split)
            train_state.log_metrics(loss=loss.item(), 
                                    accuracy=compute_accuracy(y_prediction, batch['y_target']))

            test_bar.set_postfix(loss=train_state.get_value_of(split="test", metric_name="loss"),
                                 acc=train_state.get_value_of(split="test", metric_name="accuracy"))
            test_bar.update()            
            
except KeyboardInterrupt:
    # interrupting the kernel stops training without a traceback
    print("...")