# 1. Loading data

In [40]:
import pandas as pd

In [41]:
def load_data(split_name='train', columns=['text', 'label'], folder='data'):
    '''
        "split_name" may be set as 'train', 'valid' or 'test' to load the corresponding dataset.
        
        You may also specify the column names to load any columns in the .csv data file.
        Among many, "text" can be used as model input, and "label" column is the labels (sentiment). 
    '''
    try:
        print(f"select [{', '.join(columns)}] columns from the {split_name} split")
        df = pd.read_csv(f'{folder}/{split_name}.csv')
        df = df.loc[:,columns]
        print("Success")
        return df
    except:
        print(f"Failed loading specified columns... Returning all columns from the {split_name} split")
        df = pd.read_csv(f'{folder}/{split_name}.csv')
        return df

In [42]:
train_df = load_data('train', columns=['text', 'label'], folder='data')
valid_df = load_data('valid', columns=['text', 'label'], folder='data')
# the test set labels (the 'label' column) are unavailable! So the following code will instead return all columns
test_df = load_data('test_no_label', columns=['id', 'text'], folder='data')

select [text, label] columns from the train split
Success
select [text, label] columns from the valid split
Success
select [id, text] columns from the test_no_label split
Success


# 2. Data preprocessing functions

In [43]:
import nltk
from nltk.stem import PorterStemmer
from nltk.corpus import stopwords
stopwords = set (stopwords.words('english'))
porterStemmer = PorterStemmer()

def lower(s):
    """
    :param s: a string.
    return a string with lower characters
    Note that we allow the input to be nested string of a list.
    e.g.
    Input: 'Text mining is to identify useful information.'
    Output: 'text mining is to identify useful information.'
    """
    if isinstance(s, list):
        return [lower(t) for t in s]
    if isinstance(s, str):
        return s.lower()
    else:
        raise NotImplementedError("unknown datatype")


def tokenize(text):
    """
    :param text: a doc with multiple sentences, type: str
    return a word list, type: list
    e.g.
    Input: 'Text mining is to identify useful information.'
    Output: ['Text', 'mining', 'is', 'to', 'identify', 'useful', 'information', '.']
    """
    return nltk.word_tokenize(text)


def stem(tokens):
    """
    :param tokens: a list of tokens, type: list
    return a list of stemmed words, type: list
    e.g.
    Input: ['Text', 'mining', 'is', 'to', 'identify', 'useful', 'information', '.']
    Output: ['text', 'mine', 'is', 'to', 'identifi', 'use', 'inform', '.']
    """
    ### equivalent code
    # results = list()
    # for token in tokens:
    #     results.append(ps.stem(token))
    # return results

    return [porterStemmer.stem(token) for token in tokens]

def n_gram(tokens, n=1):
    """
    :param tokens: a list of tokens, type: list
    :param n: the corresponding n-gram, type: int
    return a list of n-gram tokens, type: list
    e.g.
    Input: ['text', 'mine', 'is', 'to', 'identifi', 'use', 'inform', '.'], 2
    Output: ['text mine', 'mine is', 'is to', 'to identifi', 'identifi use', 'use inform', 'inform .']
    """
    if n == 1:
        return tokens
    else:
        results = list()
        for i in range(len(tokens)-n+1):
            # tokens[i:i+n] will return a sublist from i th to i+n th (i+n th is not included)
            results.append(" ".join(tokens[i:i+n]))
        return results

def filter_stopwords(tokens):
    """
    :param tokens: a list of tokens, type: list
    return a list of filtered tokens, type: list
    e.g.
    Input: ['text', 'mine', 'is', 'to', 'identifi', 'use', 'inform', '.']
    Output: ['text', 'mine', 'identifi', 'use', 'inform', '.']
    """
    ### equivalent code
    # results = list()
    # for token in tokens:
    #     if token not in stopwords and not token.isnumeric():
    #         results.append(token)
    # return results

    return [token for token in tokens if token not in stopwords and not token.isnumeric()]


def get_onehot_vector(feats, feats_dict):
    """
    :param data: a list of features, type: list
    :param feats_dict: a dict from features to indices, type: dict
    return a feature vector,
    """
    # initialize the vector as all zeros
    vector = np.zeros(len(feats_dict), dtype=float)
    for f in feats:
        # get the feature index, return -1 if the feature is not existed
        f_idx = feats_dict.get(f, -1)
        if f_idx != -1:
            # set the corresponding element as 1
            vector[f_idx] = 1
    return vector

def biGram(tokens):
    return n_gram(tokens, 2)

# 3. Build our model

## A. CountVectorizer + Logistic regression

### Import library

In [44]:
import numpy as np
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.metrics import confusion_matrix, classification_report, accuracy_score
from sklearn.linear_model import LogisticRegression

### Data preprocessing

In [45]:
train_df = load_data('train')
valid_df = load_data('valid')
x_train = train_df['text']
y_train = train_df['label']
x_valid = valid_df['text']
y_valid = valid_df['label']

select [text, label] columns from the train split
Success
select [text, label] columns from the valid split
Success


In [46]:
train_data_x = train_df['text'].map(tokenize).map(lower).map(filter_stopwords).map(stem)
train_data_y = train_df['label']
valid_data_x = valid_df['text'].map(tokenize).map(lower).map(filter_stopwords).map(stem)
valid_data_y = valid_df['label']

In [47]:
for i in range(len(train_data_x)):
    train_data_x[i] = ' '.join(train_data_x[i])
for i in range(len(valid_data_x)):
    valid_data_x[i] = ' '.join(valid_data_x[i])

In [48]:
train_data_x.to_csv("data/norm_train_data.csv")
valid_data_x.to_csv("data/norm_valid_data.csv")

In [49]:
norm_train_data = load_data("norm_train_data")['text']
norm_valid_data = load_data("norm_valid_data")['text']

select [text, label] columns from the norm_train_data split
Failed loading specified columns... Returning all columns from the norm_train_data split
select [text, label] columns from the norm_valid_data split
Failed loading specified columns... Returning all columns from the norm_valid_data split


### Model

In [50]:
countVectorizer = CountVectorizer(min_df=0.0, max_df=0.2, binary=False, ngram_range=(1, 3))
cV_train = countVectorizer.fit_transform(norm_train_data)
cV_valid = countVectorizer.transform(norm_valid_data)

In [51]:
lr = LogisticRegression()
lr.fit(cV_train, y_train)
lr

In [52]:
y_pred = lr.predict(cV_valid)
print(classification_report(y_valid, y_pred))
print("\n\n")
print(confusion_matrix(y_valid, y_pred))
print('accuracy', np.mean(y_valid == y_pred))

              precision    recall  f1-score   support

           1       0.57      0.52      0.54       295
           2       0.38      0.19      0.25       198
           3       0.47      0.55      0.51       508
           4       0.50      0.46      0.47       523
           5       0.61      0.71      0.65       476

    accuracy                           0.52      2000
   macro avg       0.50      0.48      0.49      2000
weighted avg       0.51      0.52      0.51      2000




[[153  28  73  20  21]
 [ 47  38  82  21  10]
 [ 41  27 279 116  45]
 [ 15   5 122 238 143]
 [ 11   3  39  85 338]]
accuracy 0.523


## B. LSTM

### Import library

In [67]:
import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset
import tqdm
from sklearn.metrics import confusion_matrix, classification_report, accuracy_score
from torch.autograd import Variable

In [68]:
if torch.backends.mps.is_available():
    mps_device = torch.device("mps")

### Data preprocessing

In [93]:
train_df = load_data('train')
valid_df = load_data('valid')
x_train = train_df['text'].map(tokenize).map(lower).map(filter_stopwords).map(stem)
y_train = train_df['label']
x_valid = valid_df['text'].map(tokenize).map(lower).map(filter_stopwords).map(stem)
y_valid = valid_df['label']

select [text, label] columns from the train split
Success
select [text, label] columns from the valid split
Success


In [94]:
for i in range(len(x_train)):
    x_train[i] = ' '.join(x_train[i])
for i in range(len(x_valid)):
    x_valid[i] = ' '.join(x_valid[i])
x_train.to_csv("data/norm_train_data.csv")
x_valid.to_csv("data/norm_valid_data.csv")

In [95]:
norm_train_data = load_data("norm_train_data")['text']
norm_valid_data = load_data("norm_valid_data")['text']

select [text, label] columns from the norm_train_data split
Failed loading specified columns... Returning all columns from the norm_train_data split
select [text, label] columns from the norm_valid_data split
Failed loading specified columns... Returning all columns from the norm_valid_data split


In [96]:
word2id = {}
for tokens in norm_train_data:
    for t in tokens:
        if not t in word2id:
            word2id[t] = len(word2id)
word2id['<pad>'] = len(word2id)

In [97]:
def texts_to_id_seq(texts, padding_length=50):
    records = []
    for tokens in texts:
        record = []
        for t in tokens:
            record.append(word2id.get(t, len(word2id)))
        if len(record) >= padding_length:
            records.append(record[:padding_length])
        else:
            records.append(record + [word2id['<pad>']] * (padding_length - len(record)))
    return records

In [98]:
train_seqs = texts_to_id_seq(norm_train_data)
valid_seqs = texts_to_id_seq(norm_valid_data)

In [99]:
X_train_tensors = Variable(torch.Tensor(train_seqs)).type(torch.LongTensor)
X_valid_tensors = Variable(torch.Tensor(valid_seqs)).type(torch.LongTensor)

y_train_tensors = Variable(torch.Tensor(y_train)).type(torch.LongTensor)
y_valid_tensors = Variable(torch.Tensor(y_valid)).type(torch.LongTensor)

In [100]:
class MyDataset(Dataset):
    
    def __init__(self, seq, y):
        assert len(seq) == len(y)
        self.seq = seq
        self.y = y-1
    
    def __getitem__(self, idx):
        return np.asarray(self.seq[idx]), self.y[idx]

    def __len__(self):
        return len(self.seq)

In [101]:
batch_size = 16

train_loader = DataLoader(MyDataset(X_train_tensors, y_train_tensors), batch_size=batch_size, shuffle=True)
valid_loader = DataLoader(MyDataset(X_valid_tensors, y_valid_tensors), batch_size=batch_size)

### Model

In [102]:
class lstm_model(nn.Module):
    def __init__(self, num_classes, input_size, hidden_size, num_layers, seq_length):
        super(lstm_model, self).__init__()
        self.num_classes = num_classes #number of classes
        self.num_layers = num_layers #number of layers
        self.input_size = input_size #input size
        self.hidden_size = hidden_size #hidden state
        self.seq_length = seq_length #sequence length

        self.embedding = nn.Embedding(num_embeddings=len(word2id)+1, embedding_dim=input_size)
        self.lstm = nn.LSTM(input_size=input_size, hidden_size=hidden_size,
                          num_layers=num_layers, batch_first=True) #lstm
        self.fc_1 =  nn.Linear(hidden_size, 128) #fully connected 1
        self.max = nn.MaxPool1d(kernel_size=3,
                                stride=1)
        self.drop = nn.Dropout(0.5)
        self.fc = nn.Linear(128, num_classes) #fully connected last layer

        self.relu = nn.ReLU()
        self.softmax = nn.Softmax()
    
    def forward(self,x):
        h_0 = Variable(torch.zeros(self.num_layers, x.size(0), self.hidden_size)) #hidden state
        c_0 = Variable(torch.zeros(self.num_layers, x.size(0), self.hidden_size)) #internal state
        # Propagate input through LSTM
        embedd = self.embedding(x)
        embedd = self.drop(embedd)
        output, (hn, cn) = self.lstm(embedd, (h_0, c_0)) #lstm with input, hidden, and internal state
        # return self.softmax(self.fc(output[:,-1,:]))
        hn = hn.view(-1, self.hidden_size) #reshaping the data for Dense layer next
        out = self.relu(hn)
        # out = self.fc_1(out) #first Dense
        # out = self.relu(out) #relu
        out = self.fc(out) #Final Output
        return out

In [103]:
num_classes, input_size, hidden_size, num_layers, seq_length = 5, 64, 128, 1, 1
model = lstm_model(num_classes, input_size, hidden_size, num_layers, seq_length)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = torch.nn.CrossEntropyLoss()     

In [104]:
for e in range(1, 31):    
    print('epoch', e)
    model.train()
    total_acc = 0
    total_loss = 0
    total_count = 0
    with tqdm.tqdm(train_loader) as t:
        for x, y in t:
            # print(x.shape)
            optimizer.zero_grad()
            logits = model(x)
            loss = criterion(logits, y)
            loss.backward()
            total_acc += (logits.argmax(1) == y).sum().item()
            total_count += y.size(0)
            total_loss += loss.item()
            optimizer.step()
            t.set_postfix({'loss': total_loss/total_count, 'acc': total_acc/total_count})

    model.eval()
    y_pred = []
    y_true = []
    with tqdm.tqdm(valid_loader) as t:
        for x, y in t:
            logits = model(x)
            total_acc += (logits.argmax(1) == y).sum().item()
            total_count += len(y)
            y_pred += logits.argmax(1).tolist()
            y_true += y.tolist()
    print(classification_report(y_true, y_pred))
    print("\n\n")
    print(confusion_matrix(y_true, y_pred))
    print('accuracy', np.mean(y_valid == y_pred))

epoch 1


100%|██████████| 1125/1125 [00:28<00:00, 40.05it/s, loss=0.0977, acc=0.251]
100%|██████████| 125/125 [00:01<00:00, 116.15it/s]
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


              precision    recall  f1-score   support

           0       0.00      0.00      0.00       295
           1       0.00      0.00      0.00       198
           2       0.00      0.00      0.00       508
           3       0.28      0.27      0.27       523
           4       0.24      0.77      0.37       476

    accuracy                           0.25      2000
   macro avg       0.10      0.21      0.13      2000
weighted avg       0.13      0.25      0.16      2000




[[  0   0   0  80 215]
 [  0   0   0  47 151]
 [  0   0   0 125 383]
 [  0   0   0 139 384]
 [  0   0   0 110 366]]
accuracy 0.2545
epoch 2


100%|██████████| 1125/1125 [00:26<00:00, 42.85it/s, loss=0.0973, acc=0.257]
100%|██████████| 125/125 [00:01<00:00, 103.83it/s]
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


              precision    recall  f1-score   support

           0       0.00      0.00      0.00       295
           1       0.00      0.00      0.00       198
           2       0.26      0.37      0.30       508
           3       0.29      0.34      0.32       523
           4       0.29      0.39      0.33       476

    accuracy                           0.28      2000
   macro avg       0.17      0.22      0.19      2000
weighted avg       0.21      0.28      0.24      2000




[[  0   0 122  92  81]
 [  0   0  93  59  46]
 [  0   0 189 158 161]
 [  0   0 174 180 169]
 [  0   0 161 128 187]]
accuracy 0.21
epoch 3


100%|██████████| 1125/1125 [00:28<00:00, 39.38it/s, loss=0.0971, acc=0.263]
100%|██████████| 125/125 [00:01<00:00, 109.01it/s]
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


              precision    recall  f1-score   support

           0       0.00      0.00      0.00       295
           1       0.00      0.00      0.00       198
           2       0.29      0.15      0.20       508
           3       0.27      0.63      0.38       523
           4       0.30      0.33      0.31       476

    accuracy                           0.28      2000
   macro avg       0.17      0.22      0.18      2000
weighted avg       0.22      0.28      0.22      2000




[[  0   0  43 182  70]
 [  0   0  30 120  48]
 [  0   0  76 320 112]
 [  0   0  51 330 142]
 [  0   0  63 256 157]]
accuracy 0.246
epoch 4


100%|██████████| 1125/1125 [00:26<00:00, 42.44it/s, loss=0.0956, acc=0.298]
100%|██████████| 125/125 [00:01<00:00, 86.81it/s]
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


              precision    recall  f1-score   support

           0       0.33      0.02      0.04       295
           1       0.00      0.00      0.00       198
           2       0.29      0.78      0.42       508
           3       0.43      0.08      0.13       523
           4       0.40      0.42      0.41       476

    accuracy                           0.32      2000
   macro avg       0.29      0.26      0.20      2000
weighted avg       0.33      0.32      0.25      2000




[[  6   0 243   8  38]
 [  1   0 173   7  17]
 [  5   0 398  16  89]
 [  5   0 325  41 152]
 [  1   0 251  23 201]]
accuracy 0.1705
epoch 5


100%|██████████| 1125/1125 [00:27<00:00, 40.27it/s, loss=0.0933, acc=0.323]
100%|██████████| 125/125 [00:01<00:00, 105.50it/s]
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


              precision    recall  f1-score   support

           0       0.43      0.02      0.04       295
           1       0.00      0.00      0.00       198
           2       0.31      0.56      0.40       508
           3       0.31      0.24      0.27       523
           4       0.40      0.58      0.48       476

    accuracy                           0.34      2000
   macro avg       0.29      0.28      0.24      2000
weighted avg       0.32      0.34      0.29      2000




[[  6   0 194  44  51]
 [  0   0 137  35  26]
 [  2   0 282 105 119]
 [  3   0 184 125 211]
 [  3   0 108  89 276]]
accuracy 0.2265
epoch 6


100%|██████████| 1125/1125 [00:28<00:00, 39.14it/s, loss=0.0914, acc=0.348]
100%|██████████| 125/125 [00:01<00:00, 106.39it/s]


              precision    recall  f1-score   support

           0       0.41      0.16      0.23       295
           1       1.00      0.01      0.01       198
           2       0.32      0.59      0.41       508
           3       0.36      0.27      0.31       523
           4       0.44      0.52      0.48       476

    accuracy                           0.37      2000
   macro avg       0.51      0.31      0.29      2000
weighted avg       0.44      0.37      0.33      2000




[[ 46   0 185  31  33]
 [ 23   1 132  14  28]
 [ 23   0 301 105  79]
 [ 12   0 199 139 173]
 [  8   0 126  95 247]]
accuracy 0.205
epoch 7


100%|██████████| 1125/1125 [00:28<00:00, 39.55it/s, loss=0.0897, acc=0.363]
100%|██████████| 125/125 [00:01<00:00, 99.72it/s]
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


              precision    recall  f1-score   support

           0       0.38      0.22      0.28       295
           1       0.00      0.00      0.00       198
           2       0.31      0.43      0.36       508
           3       0.34      0.35      0.35       523
           4       0.44      0.54      0.48       476

    accuracy                           0.36      2000
   macro avg       0.29      0.31      0.29      2000
weighted avg       0.33      0.36      0.34      2000




[[ 64   0 147  48  36]
 [ 34   0 110  29  25]
 [ 35   0 219 164  90]
 [ 22   0 141 183 177]
 [ 13   0  96 111 256]]
accuracy 0.2255
epoch 8


100%|██████████| 1125/1125 [00:28<00:00, 39.35it/s, loss=0.0881, acc=0.379]
100%|██████████| 125/125 [00:01<00:00, 106.76it/s]


              precision    recall  f1-score   support

           0       0.37      0.39      0.38       295
           1       0.00      0.00      0.00       198
           2       0.32      0.36      0.34       508
           3       0.38      0.31      0.34       523
           4       0.42      0.62      0.50       476

    accuracy                           0.38      2000
   macro avg       0.30      0.33      0.31      2000
weighted avg       0.34      0.38      0.35      2000




[[114   1  99  30  51]
 [ 62   0  86  24  26]
 [ 77   0 181 133 117]
 [ 33   1 120 161 208]
 [ 23   0  85  72 296]]
accuracy 0.214
epoch 9


100%|██████████| 1125/1125 [00:27<00:00, 41.17it/s, loss=0.087, acc=0.388] 
100%|██████████| 125/125 [00:01<00:00, 109.23it/s]


              precision    recall  f1-score   support

           0       0.39      0.37      0.38       295
           1       0.50      0.01      0.01       198
           2       0.35      0.38      0.36       508
           3       0.38      0.30      0.33       523
           4       0.42      0.66      0.51       476

    accuracy                           0.39      2000
   macro avg       0.41      0.34      0.32      2000
weighted avg       0.40      0.39      0.36      2000




[[110   0  98  24  63]
 [ 59   1  83  25  30]
 [ 59   1 195 125 128]
 [ 31   0 122 155 215]
 [ 20   0  67  74 315]]
accuracy 0.2115
epoch 10


100%|██████████| 1125/1125 [00:27<00:00, 41.62it/s, loss=0.0857, acc=0.402]
100%|██████████| 125/125 [00:01<00:00, 107.23it/s]


              precision    recall  f1-score   support

           0       0.39      0.43      0.41       295
           1       0.00      0.00      0.00       198
           2       0.35      0.42      0.38       508
           3       0.42      0.25      0.31       523
           4       0.43      0.67      0.52       476

    accuracy                           0.40      2000
   macro avg       0.32      0.35      0.32      2000
weighted avg       0.36      0.40      0.36      2000




[[128   1  96  15  55]
 [ 67   0  80  21  30]
 [ 76   1 211  93 127]
 [ 35   0 142 131 215]
 [ 25   0  80  51 320]]
accuracy 0.1945
epoch 11


100%|██████████| 1125/1125 [00:26<00:00, 42.21it/s, loss=0.0844, acc=0.414]
100%|██████████| 125/125 [00:01<00:00, 108.52it/s]


              precision    recall  f1-score   support

           0       0.48      0.27      0.35       295
           1       0.50      0.01      0.01       198
           2       0.34      0.40      0.37       508
           3       0.37      0.41      0.39       523
           4       0.46      0.64      0.54       476

    accuracy                           0.40      2000
   macro avg       0.43      0.35      0.33      2000
weighted avg       0.41      0.40      0.38      2000




[[ 80   0 119  41  55]
 [ 35   1  97  39  26]
 [ 30   1 203 173 101]
 [ 13   0 126 212 172]
 [ 10   0  54 105 307]]
accuracy 0.221
epoch 12


100%|██████████| 1125/1125 [00:27<00:00, 41.11it/s, loss=0.0837, acc=0.42] 
100%|██████████| 125/125 [00:01<00:00, 108.07it/s]


              precision    recall  f1-score   support

           0       0.36      0.55      0.43       295
           1       0.15      0.01      0.02       198
           2       0.36      0.43      0.39       508
           3       0.39      0.27      0.32       523
           4       0.49      0.58      0.53       476

    accuracy                           0.40      2000
   macro avg       0.35      0.37      0.34      2000
weighted avg       0.38      0.40      0.37      2000




[[161   4  85  16  29]
 [ 89   2  71  16  20]
 [110   4 217 107  70]
 [ 50   2 161 143 167]
 [ 38   1  76  86 275]]
accuracy 0.1745
epoch 13


100%|██████████| 1125/1125 [00:28<00:00, 39.86it/s, loss=0.0826, acc=0.432]
100%|██████████| 125/125 [00:01<00:00, 109.85it/s]


              precision    recall  f1-score   support

           0       0.46      0.40      0.43       295
           1       0.43      0.05      0.08       198
           2       0.36      0.43      0.39       508
           3       0.40      0.32      0.35       523
           4       0.46      0.66      0.54       476

    accuracy                           0.41      2000
   macro avg       0.42      0.37      0.36      2000
weighted avg       0.41      0.41      0.39      2000




[[117   3 107  20  48]
 [ 51   9  92  18  28]
 [ 50   4 220 131 103]
 [ 20   4 139 166 194]
 [ 16   1  60  83 316]]
accuracy 0.21
epoch 14


100%|██████████| 1125/1125 [00:27<00:00, 41.47it/s, loss=0.0818, acc=0.434]
100%|██████████| 125/125 [00:01<00:00, 101.02it/s]


              precision    recall  f1-score   support

           0       0.43      0.41      0.42       295
           1       0.25      0.02      0.04       198
           2       0.38      0.42      0.40       508
           3       0.40      0.36      0.38       523
           4       0.46      0.64      0.54       476

    accuracy                           0.42      2000
   macro avg       0.38      0.37      0.35      2000
weighted avg       0.40      0.42      0.39      2000




[[122   4  90  29  50]
 [ 54   4  88  23  29]
 [ 61   4 212 135  96]
 [ 27   3 119 189 185]
 [ 19   1  56  93 307]]
accuracy 0.206
epoch 15


100%|██████████| 1125/1125 [00:27<00:00, 41.11it/s, loss=0.081, acc=0.446] 
100%|██████████| 125/125 [00:01<00:00, 108.95it/s]


              precision    recall  f1-score   support

           0       0.46      0.42      0.44       295
           1       0.25      0.04      0.06       198
           2       0.36      0.47      0.41       508
           3       0.39      0.37      0.38       523
           4       0.51      0.59      0.54       476

    accuracy                           0.42      2000
   macro avg       0.39      0.38      0.37      2000
weighted avg       0.41      0.42      0.40      2000




[[124   3 107  32  29]
 [ 59   7  87  24  21]
 [ 50  11 238 138  71]
 [ 23   5 150 194 151]
 [ 11   2  76 108 279]]
accuracy 0.1895
epoch 16


100%|██████████| 1125/1125 [00:27<00:00, 40.35it/s, loss=0.0803, acc=0.445]
100%|██████████| 125/125 [00:01<00:00, 101.50it/s]


              precision    recall  f1-score   support

           0       0.44      0.43      0.44       295
           1       0.24      0.07      0.10       198
           2       0.37      0.46      0.41       508
           3       0.42      0.28      0.34       523
           4       0.48      0.69      0.56       476

    accuracy                           0.42      2000
   macro avg       0.39      0.38      0.37      2000
weighted avg       0.41      0.42      0.40      2000




[[128  13 100  14  40]
 [ 60  13  90  13  22]
 [ 59  13 234  99 103]
 [ 24  10 153 145 191]
 [ 20   5  53  71 327]]
accuracy 0.1965
epoch 17


100%|██████████| 1125/1125 [00:27<00:00, 41.34it/s, loss=0.0797, acc=0.456]
100%|██████████| 125/125 [00:01<00:00, 109.05it/s]


              precision    recall  f1-score   support

           0       0.44      0.42      0.43       295
           1       0.34      0.05      0.09       198
           2       0.41      0.34      0.37       508
           3       0.37      0.35      0.36       523
           4       0.44      0.71      0.54       476

    accuracy                           0.41      2000
   macro avg       0.40      0.37      0.36      2000
weighted avg       0.40      0.41      0.39      2000




[[123   9  68  32  63]
 [ 60  10  58  37  33]
 [ 53   6 171 152 126]
 [ 27   3  91 184 218]
 [ 18   1  30  88 339]]
accuracy 0.2185
epoch 18


100%|██████████| 1125/1125 [00:26<00:00, 41.91it/s, loss=0.079, acc=0.457] 
100%|██████████| 125/125 [00:01<00:00, 110.22it/s]


              precision    recall  f1-score   support

           0       0.50      0.35      0.41       295
           1       0.32      0.06      0.10       198
           2       0.35      0.54      0.43       508
           3       0.37      0.33      0.35       523
           4       0.54      0.58      0.56       476

    accuracy                           0.42      2000
   macro avg       0.42      0.37      0.37      2000
weighted avg       0.42      0.42      0.40      2000




[[102   8 136  30  19]
 [ 43  12 110  19  14]
 [ 36   7 276 130  59]
 [ 15   7 183 173 145]
 [  9   4  79 110 274]]
accuracy 0.1965
epoch 19


100%|██████████| 1125/1125 [00:26<00:00, 43.13it/s, loss=0.0785, acc=0.464]
100%|██████████| 125/125 [00:01<00:00, 104.18it/s]


              precision    recall  f1-score   support

           0       0.43      0.48      0.45       295
           1       0.26      0.12      0.16       198
           2       0.39      0.43      0.41       508
           3       0.40      0.32      0.36       523
           4       0.50      0.62      0.56       476

    accuracy                           0.42      2000
   macro avg       0.40      0.40      0.39      2000
weighted avg       0.41      0.42      0.41      2000




[[142  24  72  22  35]
 [ 63  23  64  26  22]
 [ 65  26 220 122  75]
 [ 33  11 148 169 162]
 [ 29   4  59  88 296]]
accuracy 0.186
epoch 20


100%|██████████| 1125/1125 [00:33<00:00, 33.61it/s, loss=0.0776, acc=0.466]
100%|██████████| 125/125 [00:01<00:00, 73.82it/s]


              precision    recall  f1-score   support

           0       0.42      0.51      0.46       295
           1       0.23      0.04      0.06       198
           2       0.39      0.38      0.39       508
           3       0.40      0.37      0.38       523
           4       0.49      0.65      0.56       476

    accuracy                           0.43      2000
   macro avg       0.38      0.39      0.37      2000
weighted avg       0.40      0.43      0.41      2000




[[151   5  72  30  37]
 [ 68   7  74  28  21]
 [ 78  11 195 135  89]
 [ 36   5 116 191 175]
 [ 28   3  47  89 309]]
accuracy 0.1945
epoch 21


100%|██████████| 1125/1125 [00:39<00:00, 28.83it/s, loss=0.0773, acc=0.47] 
100%|██████████| 125/125 [00:01<00:00, 100.92it/s]


              precision    recall  f1-score   support

           0       0.42      0.51      0.46       295
           1       0.18      0.05      0.08       198
           2       0.40      0.40      0.40       508
           3       0.41      0.37      0.39       523
           4       0.49      0.62      0.55       476

    accuracy                           0.43      2000
   macro avg       0.38      0.39      0.38      2000
weighted avg       0.41      0.43      0.41      2000




[[151  11  73  22  38]
 [ 77  10  66  23  22]
 [ 73  19 205 133  78]
 [ 35   8 122 193 165]
 [ 22   8  52  99 295]]
accuracy 0.1875
epoch 22


100%|██████████| 1125/1125 [00:36<00:00, 31.08it/s, loss=0.0771, acc=0.477]
100%|██████████| 125/125 [00:01<00:00, 97.40it/s] 


              precision    recall  f1-score   support

           0       0.42      0.49      0.45       295
           1       0.32      0.06      0.09       198
           2       0.37      0.47      0.42       508
           3       0.40      0.30      0.34       523
           4       0.50      0.62      0.56       476

    accuracy                           0.42      2000
   macro avg       0.40      0.39      0.37      2000
weighted avg       0.41      0.42      0.40      2000




[[144   8  94  19  30]
 [ 69  11  82  17  19]
 [ 68  10 241 113  76]
 [ 38   4 155 156 170]
 [ 20   1  74  84 297]]
accuracy 0.1865
epoch 23


100%|██████████| 1125/1125 [00:29<00:00, 38.44it/s, loss=0.0762, acc=0.481]
100%|██████████| 125/125 [00:01<00:00, 108.12it/s]


              precision    recall  f1-score   support

           0       0.46      0.45      0.46       295
           1       0.25      0.12      0.16       198
           2       0.37      0.44      0.41       508
           3       0.38      0.30      0.34       523
           4       0.50      0.62      0.55       476

    accuracy                           0.42      2000
   macro avg       0.39      0.39      0.38      2000
weighted avg       0.41      0.42      0.41      2000




[[133  21  89  22  30]
 [ 55  24  79  22  18]
 [ 53  25 226 124  80]
 [ 25  20 151 159 168]
 [ 21   7  61  92 295]]
accuracy 0.196
epoch 24


100%|██████████| 1125/1125 [00:28<00:00, 39.86it/s, loss=0.0758, acc=0.483]
100%|██████████| 125/125 [00:01<00:00, 109.82it/s]


              precision    recall  f1-score   support

           0       0.45      0.39      0.42       295
           1       0.22      0.07      0.11       198
           2       0.37      0.42      0.39       508
           3       0.37      0.37      0.37       523
           4       0.48      0.60      0.54       476

    accuracy                           0.41      2000
   macro avg       0.38      0.37      0.37      2000
weighted avg       0.39      0.41      0.40      2000




[[116  19  92  35  33]
 [ 50  14  84  28  22]
 [ 50  19 213 143  83]
 [ 18  10 136 191 168]
 [ 21   3  51 115 286]]
accuracy 0.207
epoch 25


100%|██████████| 1125/1125 [00:26<00:00, 42.08it/s, loss=0.0752, acc=0.484]
100%|██████████| 125/125 [00:01<00:00, 110.46it/s]


              precision    recall  f1-score   support

           0       0.43      0.46      0.44       295
           1       0.25      0.11      0.15       198
           2       0.38      0.46      0.42       508
           3       0.40      0.31      0.35       523
           4       0.50      0.61      0.55       476

    accuracy                           0.42      2000
   macro avg       0.39      0.39      0.38      2000
weighted avg       0.41      0.42      0.41      2000




[[137  25  80  27  26]
 [ 61  22  78  21  16]
 [ 70  21 232 104  81]
 [ 31  12 154 164 162]
 [ 22   7  61  97 289]]
accuracy 0.1845
epoch 26


100%|██████████| 1125/1125 [00:26<00:00, 42.27it/s, loss=0.0747, acc=0.494]
100%|██████████| 125/125 [00:01<00:00, 103.60it/s]


              precision    recall  f1-score   support

           0       0.42      0.51      0.46       295
           1       0.35      0.12      0.17       198
           2       0.40      0.38      0.39       508
           3       0.38      0.37      0.37       523
           4       0.49      0.60      0.54       476

    accuracy                           0.42      2000
   macro avg       0.41      0.39      0.39      2000
weighted avg       0.41      0.42      0.41      2000




[[149  18  62  37  29]
 [ 71  23  56  30  18]
 [ 79  15 192 142  80]
 [ 33   7 124 193 166]
 [ 24   3  50 112 287]]
accuracy 0.191
epoch 27


100%|██████████| 1125/1125 [00:27<00:00, 40.98it/s, loss=0.0744, acc=0.496]
100%|██████████| 125/125 [00:01<00:00, 106.68it/s]


              precision    recall  f1-score   support

           0       0.41      0.45      0.43       295
           1       0.25      0.11      0.15       198
           2       0.40      0.36      0.38       508
           3       0.38      0.43      0.41       523
           4       0.51      0.57      0.54       476

    accuracy                           0.42      2000
   macro avg       0.39      0.38      0.38      2000
weighted avg       0.41      0.42      0.41      2000




[[134  24  61  41  35]
 [ 74  21  56  27  20]
 [ 72  23 182 165  66]
 [ 31  11 112 226 143]
 [ 19   4  48 132 273]]
accuracy 0.194
epoch 28


100%|██████████| 1125/1125 [00:27<00:00, 41.66it/s, loss=0.0738, acc=0.5]  
100%|██████████| 125/125 [00:01<00:00, 109.55it/s]


              precision    recall  f1-score   support

           0       0.46      0.39      0.42       295
           1       0.30      0.11      0.16       198
           2       0.37      0.49      0.42       508
           3       0.37      0.29      0.33       523
           4       0.49      0.62      0.55       476

    accuracy                           0.42      2000
   macro avg       0.40      0.38      0.38      2000
weighted avg       0.41      0.42      0.40      2000




[[114  19 104  21  37]
 [ 53  22  80  21  22]
 [ 46  19 248 119  76]
 [ 20   9 163 151 180]
 [ 14   4  69  92 297]]
accuracy 0.199
epoch 29


100%|██████████| 1125/1125 [00:27<00:00, 40.72it/s, loss=0.0736, acc=0.501]
100%|██████████| 125/125 [00:01<00:00, 109.75it/s]


              precision    recall  f1-score   support

           0       0.44      0.44      0.44       295
           1       0.29      0.12      0.17       198
           2       0.38      0.53      0.44       508
           3       0.40      0.27      0.33       523
           4       0.51      0.60      0.55       476

    accuracy                           0.42      2000
   macro avg       0.40      0.39      0.38      2000
weighted avg       0.42      0.42      0.41      2000




[[131  22  97  20  25]
 [ 58  23  79  18  20]
 [ 67  16 268  86  71]
 [ 28  11 182 143 159]
 [ 17   6  79  89 285]]
accuracy 0.173
epoch 30


100%|██████████| 1125/1125 [00:27<00:00, 40.76it/s, loss=0.0733, acc=0.505]
100%|██████████| 125/125 [00:01<00:00, 108.77it/s]

              precision    recall  f1-score   support

           0       0.47      0.35      0.40       295
           1       0.33      0.10      0.16       198
           2       0.38      0.47      0.42       508
           3       0.37      0.35      0.36       523
           4       0.49      0.61      0.55       476

    accuracy                           0.42      2000
   macro avg       0.41      0.38      0.38      2000
weighted avg       0.41      0.42      0.41      2000




[[104  20 101  32  38]
 [ 48  20  77  29  24]
 [ 41  11 240 140  76]
 [ 16   6 154 185 162]
 [ 12   3  59 110 292]]
accuracy 0.1995



