In [1]:
CUDA_VISIBLE_DEVICES=1

In [2]:
import torch
import torch.nn as nn
import os
import logging
import numpy as np
import random
from tqdm import tqdm
import time
import pandas as pd

from transformers import LongformerModel, AutoTokenizer, LongformerForSequenceClassification, LongformerForMultipleChoice
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from sklearn.metrics import f1_score, accuracy_score, classification_report

logging.basicConfig(filename=f'./logs/train_{time.asctime().replace(" ","_")}.log', filemode='w', level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# Create a logger object
logger = logging.getLogger()
logger.setLevel(logging.INFO)

# Create a stream handler to print log messages to the console
console_handler = logging.StreamHandler()
console_handler.setLevel(logging.INFO)
formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
console_handler.setFormatter(formatter)
logger.addHandler(console_handler)

torch.manual_seed(40)
np.random.seed(40)
random.seed(40)
torch.cuda.manual_seed(40)
torch.backends.cudnn.deterministic = True

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Define the path to the CSV file
train_csv_file = "/data1/debajyoti/colie/train.csv"
val_csv_file = "/data1/debajyoti/colie/valid.csv"

# Read the CSV file
train_labels = pd.read_csv(train_csv_file)
val_labels = pd.read_csv(val_csv_file)
val_labels

Unnamed: 0,BOOK_id,Epoch
0,31873_1.txt,Viktorian
1,31873_2.txt,Viktorian
2,31873_3.txt,Viktorian
3,31873_4.txt,Viktorian
4,31873_5.txt,Viktorian
...,...,...
36252,36919_48.txt,Modernism
36253,36919_49.txt,Modernism
36254,36919_50.txt,Modernism
36255,36919_51.txt,Modernism


In [4]:
train_labels.BOOK_id[0]

'27993_1.txt'

In [5]:
# Define the path to the train folder
train_folder = "/data1/debajyoti/colie/train/train/"
# Define the path to the validation folder
val_folder = "/data1/debajyoti/colie/valid/valid/"



def create_df(folder, label):
    # Initialize empty lists to store the data
    text_data = []
    labels = []
    for index in label.index:
        # filename = df_labels.BOOK_id[index]
        # print(filename)
        # print(df_labels['BOOK_id'][index], df_labels['Epoch'][index])
        file_name = label['BOOK_id'][index]  # Assuming 'File Name' is the column name for the file names in the CSV

        # Construct the file path
        file_path = os.path.join(folder, file_name)

        # Read the text from the file
        with open(file_path, 'r', encoding='ISO-8859-1') as file:
            text = file.read()

        # Append the text and label to the respective lists
        text_data.append(text)
        labels.append(label['Epoch'][index].strip())  # Assuming 'Label' is the column name for the labels in the CSV
        # break
    return text_data, labels

train_data, train_label = create_df(train_folder, train_labels)
val_data, val_label = create_df(val_folder, val_labels)

# Create a dataframe from the lists
train = pd.DataFrame({'text': train_data, 'label': train_label})
val = pd.DataFrame({'text': val_data, 'label': val_label})
print(train.head(), val.head())
print(train.shape, val.shape)

                                                text      label
0  rifle; Ivan's was a double-barrelled shot-gun ...  Viktorian
1  upon the track of the bear. After following it...  Viktorian
2  to pull him out with their hands--even had the...  Viktorian
3  a slight sparkle of scientific conceit, "this ...  Viktorian
4  bears with a white ring round their necks? Yes...  Viktorian                                                 text      label
0  kind good morning, and returned her hearty emb...  Viktorian
1  sky, and of the moon, which clothed the old pi...  Viktorian
2  left Rome for Augsburg, my mind being much exc...  Viktorian
3  thoughts some of the old melodies he knew by h...  Viktorian
4  "But," said Henry, "is it not possible that th...  Viktorian
(546210, 2) (36257, 2)


In [6]:
label_dic = {'Romanticism':0,
            'Viktorian':1,
            'Modernism':2,
            'PostModernism':3,
            'OurDays':4}
train['label'] = train['label'].map(label_dic)
val['label'] = val['label'].map(label_dic)

In [7]:
# Length of text
def length (txt):
    length = len(txt.split())
    return length

txt_length = train['text'].apply(lambda x: length(x))
print(txt_length.sort_values(ascending = False))

483268    1128
483267    1068
521384    1065
483265    1034
81542     1020
          ... 
470405       1
130188       1
217335       1
351867       1
368135       1
Name: text, Length: 546210, dtype: int64


In [8]:
val['label'].value_counts()

label
1    16938
2    14848
3     1713
4     1600
0     1158
Name: count, dtype: int64

In [9]:
# model = LongformerModel.from_pretrained("allenai/longformer-base-4096")
tokenizer = AutoTokenizer.from_pretrained("allenai/longformer-base-4096")

In [10]:
max_length= 1200
class CustomDataset(Dataset):
    def __init__(self, tokenizer, df):
        # Initialize thetokenizer
        self.tokenizer = tokenizer

        self.df = df

    def __len__(self):
        return len(self.df)

    def __getitem__(self, index):
        # Get the text and label from the dataframe
        text = self.df.iloc[index]['text']
        label = self.df.iloc[index]['label']

        # Tokenize the text and convert it to input IDs
        inputs = self.tokenizer(
            text,
            None,
            add_special_tokens=False,
            max_length=max_length,
            padding='max_length',
            truncation=True,
            return_token_type_ids=False,
            return_attention_mask=True,
            return_tensors='pt',
        )


        # Return the input IDs and label as PyTorch tensors
        return {
            'input_ids': inputs['input_ids'][0],
            'attention_mask': inputs['attention_mask'][0],
            # 'token_type_ids': inputs['token_type_ids'][0],
            'label': torch.tensor(label, dtype=torch.int64),
        }

# datasetclass = CustomDataset(tokenizer, train)
train_dataset = CustomDataset(tokenizer, train)
val_dataset = CustomDataset(tokenizer, val)

# DataLoader
batch_size = 24
train_dataloader = tqdm(DataLoader(train_dataset, batch_size=batch_size, shuffle=True))
val_dataloader = tqdm(DataLoader(val_dataset, batch_size=batch_size, shuffle=True))


  0%|          | 0/22759 [00:00<?, ?it/s]

In [11]:
class TransformerModel(nn.Module):
    def __init__(self, num_labels):
        super(TransformerModel, self).__init__()
        
        self.Longformer = LongformerModel.from_pretrained("allenai/longformer-base-4096")
        # self.xlnet.resize_token_embeddings(num_tokens)
        # self.transformer_encoder = TransformerEncoder(TransformerEncoderLayer(d_model=hidden_size, nhead=num_heads), num_layers=num_layers)
        #self.transformer_decoder = TransformerDecoder(TransformerDecoderLayer(d_model=hidden_size, nhead=num_heads), num_layers=num_layers)
        #self.transformer = Transformer(nhead=16, num_encoder_layers=6, num_decoder_layers = 6)
        self.decoder = nn.Linear(self.Longformer.config.hidden_size, num_labels) 
        # self.fc1 = nn.Linear(num_tokens, 2)
        # self.fc2 = nn.Linear(num_tokens, 2)
        # self.fc3 = nn.Linear(num_tokens, 5)
        # self.num_classes = num_classes
        # self.classifiers = nn.ModuleList([nn.Linear(self.roberta.config.hidden_size, num_classes[i]) for i in range(len(num_classes))])
        # self.classifiers = nn.ModuleList([nn.Linear(num_tokens, num_classes[i]) for i in range(len(num_classes))])
        # self.tanh = nn.Tanh()

    def forward(self, input_ids, attention_mask):  # src = [bsz, seq_len]
        long_output = self.Longformer(input_ids=input_ids).pooler_output
        # print(long_output.shape)
        # roberta_outputs = self.roberta(input_ids=input_ids, attention_mask=attention_mask)
        # last_hidden_state = outputs.last_hidden_state # Shape: (batch_size, sequence_len, hidden_size)
        # src_embedded = last_hidden_state
        # src_embedded = self.roberta.embeddings(src) # Use RoBERTa model to embed source sequence output: [bsz, seq_len, features,i.e. hidden_dim] [20, 100, 768]
        # print("shape of roberta embeddings:", src_embedded.shape)
        #tgt_embedded = self.roberta.embeddings(tgt) # Use RoBERTa model to embed target sequence
        # src_embedded = src_embedded # output: [bsz, seq_len, features] 
        # src_embedded = torch.cat([t1,t2,t3, src_embedded],1)

        # t1 = torch.cat(src_embedded.size(0) * [t1])
        # t2 = torch.cat(src_embedded.size(0) * [t2])
        # t3 = torch.cat(src_embedded.size(0) * [t3])
        # t = torch.stack([t1,t2,t3], dim=1)
        # task_embedded = torch.cat([t, src_embedded],1)  # output shape: [bsz, seq_len, features] [8, 203, 768]

        # memory = self.transformer_encoder(src_embedded)  # output shape: [bsz, seq_len, features] [8, 203, 768]
        # print("shape after transformer encoder layer:", memory.shape)
        #output = self.transformer_decoder(tgt_embedded, memory)
        #print("shape after transformer decoder layer:", output.shape)

        output = self.decoder(long_output)  # output shape: [bsz, seq_len, vocab_size] [8, 203, 50k]
        # print("shape after transformer decoder layer:", output.shape, output.dtype)
        # task1_output = self.fc1(output[:,0,:])
        # task2_output = self.fc2(output[:,1,:])
        # task3_output = self.fc3(output[:,2,:num_classes])
        # ae_output = output[:,len(self.num_classes):,:]
        # ae_output = output[:,:,:]
        # print("shape after final linear layer:", output.shape)
        # task_logits = [classifier(pooled_output) for classifier in self.classifiers]
        # task_logits = []

        # pooled_outputs = [output[:,i,:] for i in range(len(self.num_classes))] # output shape : [bsz, 1, vocab_size]

        # for classifier, pooled_output in zip(self.classifiers, pooled_outputs):
        #     # pooled_output = self.tanh(pooled_output)
        #     logits = classifier(pooled_output)
        #     task_logits.append(logits)
        
        return output

device = torch.device('cuda:1' if torch.cuda.is_available() else 'cpu')


In [12]:
num_labels = 5

model = TransformerModel(num_labels).to(device)


Some weights of the model checkpoint at allenai/longformer-base-4096 were not used when initializing LongformerModel: ['lm_head.layer_norm.bias', 'lm_head.decoder.weight', 'lm_head.layer_norm.weight', 'lm_head.bias', 'lm_head.dense.bias', 'lm_head.dense.weight']
- This IS expected if you are initializing LongformerModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing LongformerModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [13]:
# num_epochs = 5
learning_rate = 2e-5
class_weights = torch.tensor([0.35, 0.03, 0.03, 0.25, 0.34])

# Set optimizer and learning rate scheduler
criterion = nn.CrossEntropyLoss(weight=class_weights).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1.0, gamma=0.95)

In [14]:
def get_labels(logit, targets):
    """
    Calculate accuracy and macro F1 score for each class
    """
    # pos = list(task_dict.keys()).index(task_name)
    # mask = torch.arange(targets.shape[0]).to(device)
    # task_idx = mask[targets[:,pos] != 99]
    output = logit
    true_label = targets
    # print("shapes for label:", output.shape, true_label.shape)
    pred_label = torch.argmax(output, dim=1).flatten().tolist()
    true_label = true_label.flatten().tolist()


    return pred_label, true_label

In [15]:
current_train_loss = []

def train(model: nn.Module) -> None:
    model.train()  # turn on train mode
    total_loss = 0.
    log_interval = 1
    start_time = time.time()
    num_batches = len(train_dataset) // batch_size
    for batch, i in enumerate(train_dataloader):
        data, mask, targets = i.values()
        data = data.to(device)
        mask = mask.to(device)
        targets = targets.to(device)
        # print(data.dtype)        
        # print(data.shape)
        # task_logits, ae_output = model(data)
        output = model(data, mask)
        # t1_out, t2_out, t3_out, auto_output = model(data, t1, t2, t3)
        # loss = custom_loss(logits_task1, logits_task2, logits_task3, targets)
        # print("shape:", data.shape, targets.flatten().shape)
        # print("datatype:", data.dtype, targets.flatten().dtype)
        loss = criterion(output, targets.flatten())


        optimizer.zero_grad()
        loss.backward()
        # torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
        optimizer.step()

        total_loss += loss.item()
        if batch % log_interval == 0 and batch > 0:
            lr = scheduler.get_last_lr()[0]
            ms_per_batch = (time.time() - start_time) * 1000 / log_interval
            cur_loss = total_loss / log_interval
            # ppl = np.exp(cur_loss)
            print(f'| epoch {epoch:3d} | {batch:5d}/{num_batches:5d} batches | '
                    f'lr {lr:02.7f} | ms/batch {ms_per_batch:5.2f} | '
                    f'loss {cur_loss:5.5f}')
            total_loss = 0
            start_time = time.time()
        
        if batch == 100:
            break
    current_train_loss.append(cur_loss)


In [16]:
def evaluate(model: nn.Module) -> float:
    model.eval()  # turn on evaluation mode
    total_loss = 0.
    # src_mask = generate_square_subsequent_mask(bptt).to(device)
    with torch.no_grad():
        predictions = []
        true_labels = []
        for batch, i in enumerate(val_dataloader):
            data, mask, targets = i.values()
            data = data.to(device)
            mask = mask.to(device)
            targets = targets.to(device)
            seq_len = data.size(1)
            # logits_task1, logits_task2, logits_task3, ae_output = model(data, mask)
            # task_logits, ae_output = model(data)
            # task_logits, ae_output = model(data, mask)
            output = model(data, mask)
            # t1_out, t2_out, t3_out, auto_output = model(data, t1, t2, t3)
            # loss = custom_loss(logits_task1, logits_task2, logits_task3, targets)
            # loss = custom_loss(logits_task1, logits_task2, logits_task3, ae_output, data, targets)
            loss = criterion(output, targets.flatten())

            total_loss += seq_len * loss.item()

            #get the labels for classification report
            pred_label, true_label = get_labels(output, targets)
            predictions.extend(pred_label)
            true_labels.extend(true_label)
            # if batch == 100:
            #     break

    # Compute overall classification report
    logging.info(f"\n Scores:")
    logging.info(f"\n {classification_report(true_labels, predictions)}")
    return total_loss / (len(val_dataset) - 1)


In [17]:
logging.info(f"#"* 89)
logging.info(f"\n DESCRIPTION-> \n logic: longformer + linear_layer + loss_reweighting(100 batches), model: {tokenizer.name_or_path}, lr:{learning_rate}, max_seq_length: {max_length}")
logging.info('#' * 89)

2023-07-17 11:26:35,082 - INFO - #########################################################################################
2023-07-17 11:26:35,083 - INFO - 
 DESCRIPTION-> 
 logic: longformer + linear_layer + loss_reweighting(100 batches), model: allenai/longformer-base-4096, lr:2e-05, max_seq_length: 1200
2023-07-17 11:26:35,084 - INFO - #########################################################################################


In [18]:
best_val_loss = float('inf')
current_val_loss = []   # for plotting graph of val_loss
epochs = 40
early_stop_thresh = 3

tempdir = '/data1/debajyoti/colie/.temp/'
best_model_params_path = os.path.join(tempdir, f"best_model_params_{time.asctime().replace(' ','_')}.pt")

for epoch in range(1, epochs + 1):
    epoch_start_time = time.time()
    train(model)
    val_loss = evaluate(model)
    current_val_loss.append(val_loss)
    # val_ppl = np.exp(val_loss)
    elapsed = time.time() - epoch_start_time
    logging.info('-' * 89)
    logging.info(f'| end of epoch {epoch:3d} | time: {elapsed:5.2f}s | '
        f'valid loss {val_loss:5.5f}')
    logging.info('-' * 89)

    if val_loss < best_val_loss:
        best_val_loss = val_loss
        best_epoch = epoch
        torch.save(model.state_dict(), best_model_params_path)
    elif epoch - best_epoch > early_stop_thresh:
        logging.info("Early stopped training at epoch %d" % epoch)
        break  # terminate the training loop

    scheduler.step()
model.load_state_dict(torch.load(best_model_params_path)) # load best model states


  0%|          | 2/22759 [00:10<29:54:27,  4.73s/it]

| epoch   1 |     1/22758 batches | lr 0.0000200 | ms/batch 6687.19 | loss 3.31908


  0%|          | 3/22759 [00:13<24:43:11,  3.91s/it]

| epoch   1 |     2/22758 batches | lr 0.0000200 | ms/batch 2934.17 | loss 1.57182


  0%|          | 4/22759 [00:16<22:13:32,  3.52s/it]

| epoch   1 |     3/22758 batches | lr 0.0000200 | ms/batch 2911.35 | loss 1.62768


  0%|          | 5/22759 [00:19<20:51:32,  3.30s/it]

| epoch   1 |     4/22758 batches | lr 0.0000200 | ms/batch 2916.93 | loss 1.60275


  0%|          | 6/22759 [00:21<20:01:54,  3.17s/it]

| epoch   1 |     5/22758 batches | lr 0.0000200 | ms/batch 2915.44 | loss 1.63506


  0%|          | 7/22759 [00:24<19:30:47,  3.09s/it]

| epoch   1 |     6/22758 batches | lr 0.0000200 | ms/batch 2918.64 | loss 1.59265


  0%|          | 8/22759 [00:27<19:09:20,  3.03s/it]

| epoch   1 |     7/22758 batches | lr 0.0000200 | ms/batch 2910.22 | loss 1.60107


  0%|          | 9/22759 [00:30<18:55:36,  3.00s/it]

| epoch   1 |     8/22758 batches | lr 0.0000200 | ms/batch 2915.50 | loss 1.57812


  0%|          | 10/22759 [00:33<18:47:11,  2.97s/it]

| epoch   1 |     9/22758 batches | lr 0.0000200 | ms/batch 2923.52 | loss 1.66095


  0%|          | 11/22759 [00:36<18:39:57,  2.95s/it]

| epoch   1 |    10/22758 batches | lr 0.0000200 | ms/batch 2910.73 | loss 1.62084


  0%|          | 12/22759 [00:39<18:34:31,  2.94s/it]

| epoch   1 |    11/22758 batches | lr 0.0000200 | ms/batch 2907.30 | loss 1.60792


  0%|          | 13/22759 [00:42<18:31:40,  2.93s/it]

| epoch   1 |    12/22758 batches | lr 0.0000200 | ms/batch 2915.10 | loss 1.59776


  0%|          | 14/22759 [00:45<18:29:12,  2.93s/it]

| epoch   1 |    13/22758 batches | lr 0.0000200 | ms/batch 2911.31 | loss 1.60637


  0%|          | 15/22759 [00:48<18:27:35,  2.92s/it]

| epoch   1 |    14/22758 batches | lr 0.0000200 | ms/batch 2912.11 | loss 1.55568


  0%|          | 16/22759 [00:51<18:26:25,  2.92s/it]

| epoch   1 |    15/22758 batches | lr 0.0000200 | ms/batch 2911.92 | loss 1.55268


  0%|          | 17/22759 [00:53<18:26:52,  2.92s/it]

| epoch   1 |    16/22758 batches | lr 0.0000200 | ms/batch 2923.27 | loss 1.56571


  0%|          | 18/22759 [00:56<18:25:44,  2.92s/it]

| epoch   1 |    17/22758 batches | lr 0.0000200 | ms/batch 2910.53 | loss 1.54393


  0%|          | 19/22759 [00:59<18:27:58,  2.92s/it]

| epoch   1 |    18/22758 batches | lr 0.0000200 | ms/batch 2936.62 | loss 1.54037


  0%|          | 20/22759 [01:02<18:29:54,  2.93s/it]

| epoch   1 |    19/22758 batches | lr 0.0000200 | ms/batch 2940.67 | loss 1.70009


  0%|          | 21/22759 [01:05<18:30:36,  2.93s/it]

| epoch   1 |    20/22758 batches | lr 0.0000200 | ms/batch 2935.19 | loss 1.62295


  0%|          | 22/22759 [01:08<18:31:11,  2.93s/it]

| epoch   1 |    21/22758 batches | lr 0.0000200 | ms/batch 2936.04 | loss 1.47377


  0%|          | 23/22759 [01:11<18:31:01,  2.93s/it]

| epoch   1 |    22/22758 batches | lr 0.0000200 | ms/batch 2931.17 | loss 1.58164


  0%|          | 24/22759 [01:14<18:35:39,  2.94s/it]

| epoch   1 |    23/22758 batches | lr 0.0000200 | ms/batch 2972.83 | loss 1.57515


  0%|          | 25/22759 [01:17<18:33:39,  2.94s/it]

| epoch   1 |    24/22758 batches | lr 0.0000200 | ms/batch 2926.31 | loss 1.46694


  0%|          | 26/22759 [01:20<18:32:03,  2.94s/it]

| epoch   1 |    25/22758 batches | lr 0.0000200 | ms/batch 2924.85 | loss 1.48548


  0%|          | 27/22759 [01:23<18:32:18,  2.94s/it]

| epoch   1 |    26/22758 batches | lr 0.0000200 | ms/batch 2937.65 | loss 1.87018


  0%|          | 28/22759 [01:26<18:29:53,  2.93s/it]

| epoch   1 |    27/22758 batches | lr 0.0000200 | ms/batch 2914.90 | loss 1.58744


  0%|          | 29/22759 [01:29<18:28:43,  2.93s/it]

| epoch   1 |    28/22758 batches | lr 0.0000200 | ms/batch 2919.50 | loss 1.52318


  0%|          | 30/22759 [01:32<18:27:46,  2.92s/it]

| epoch   1 |    29/22758 batches | lr 0.0000200 | ms/batch 2918.75 | loss 1.58006


  0%|          | 31/22759 [01:35<18:31:02,  2.93s/it]

| epoch   1 |    30/22758 batches | lr 0.0000200 | ms/batch 2953.27 | loss 1.71528


  0%|          | 32/22759 [01:37<18:28:41,  2.93s/it]

| epoch   1 |    31/22758 batches | lr 0.0000200 | ms/batch 2912.71 | loss 1.49429


  0%|          | 33/22759 [01:40<18:27:03,  2.92s/it]

| epoch   1 |    32/22758 batches | lr 0.0000200 | ms/batch 2913.02 | loss 1.67318


  0%|          | 34/22759 [01:43<18:27:17,  2.92s/it]

| epoch   1 |    33/22758 batches | lr 0.0000200 | ms/batch 2925.09 | loss 1.80831


  0%|          | 35/22759 [01:46<18:29:08,  2.93s/it]

| epoch   1 |    34/22758 batches | lr 0.0000200 | ms/batch 2939.98 | loss 1.75395


  0%|          | 36/22759 [01:49<18:28:44,  2.93s/it]

| epoch   1 |    35/22758 batches | lr 0.0000200 | ms/batch 2925.45 | loss 1.71122


  0%|          | 37/22759 [01:52<18:29:09,  2.93s/it]

| epoch   1 |    36/22758 batches | lr 0.0000200 | ms/batch 2931.60 | loss 1.52372


  0%|          | 38/22759 [01:55<18:32:53,  2.94s/it]

| epoch   1 |    37/22758 batches | lr 0.0000200 | ms/batch 2962.04 | loss 1.64488


  0%|          | 39/22759 [01:58<18:31:08,  2.93s/it]

| epoch   1 |    38/22758 batches | lr 0.0000200 | ms/batch 2923.72 | loss 1.63901


  0%|          | 40/22759 [02:01<18:30:31,  2.93s/it]

| epoch   1 |    39/22758 batches | lr 0.0000200 | ms/batch 2929.29 | loss 1.56547


  0%|          | 41/22759 [02:04<18:29:05,  2.93s/it]

| epoch   1 |    40/22758 batches | lr 0.0000200 | ms/batch 2920.64 | loss 1.62322


  0%|          | 42/22759 [02:07<18:28:39,  2.93s/it]

| epoch   1 |    41/22758 batches | lr 0.0000200 | ms/batch 2925.61 | loss 1.58416


  0%|          | 43/22759 [02:10<18:28:51,  2.93s/it]

| epoch   1 |    42/22758 batches | lr 0.0000200 | ms/batch 2930.24 | loss 1.58975


  0%|          | 44/22759 [02:13<18:29:34,  2.93s/it]

| epoch   1 |    43/22758 batches | lr 0.0000200 | ms/batch 2935.48 | loss 1.61509


  0%|          | 45/22759 [02:16<18:31:00,  2.93s/it]

| epoch   1 |    44/22758 batches | lr 0.0000200 | ms/batch 2943.81 | loss 1.67200


  0%|          | 46/22759 [02:18<18:29:29,  2.93s/it]

| epoch   1 |    45/22758 batches | lr 0.0000200 | ms/batch 2921.67 | loss 1.52760


  0%|          | 47/22759 [02:21<18:28:00,  2.93s/it]

| epoch   1 |    46/22758 batches | lr 0.0000200 | ms/batch 2918.07 | loss 1.58483


  0%|          | 48/22759 [02:24<18:34:50,  2.95s/it]

| epoch   1 |    47/22758 batches | lr 0.0000200 | ms/batch 2987.68 | loss 1.68172


  0%|          | 49/22759 [02:27<18:32:45,  2.94s/it]

| epoch   1 |    48/22758 batches | lr 0.0000200 | ms/batch 2927.29 | loss 1.58837


  0%|          | 50/22759 [02:30<18:30:39,  2.93s/it]

| epoch   1 |    49/22758 batches | lr 0.0000200 | ms/batch 2921.66 | loss 1.54884


  0%|          | 51/22759 [02:33<18:29:26,  2.93s/it]

| epoch   1 |    50/22758 batches | lr 0.0000200 | ms/batch 2923.93 | loss 1.67857


  0%|          | 52/22759 [02:36<18:28:46,  2.93s/it]

| epoch   1 |    51/22758 batches | lr 0.0000200 | ms/batch 2925.23 | loss 1.64099


  0%|          | 53/22759 [02:39<18:28:33,  2.93s/it]

| epoch   1 |    52/22758 batches | lr 0.0000200 | ms/batch 2928.35 | loss 1.64869


  0%|          | 54/22759 [02:42<18:27:52,  2.93s/it]

| epoch   1 |    53/22758 batches | lr 0.0000200 | ms/batch 2922.43 | loss 1.55526


  0%|          | 55/22759 [02:45<18:28:36,  2.93s/it]

| epoch   1 |    54/22758 batches | lr 0.0000200 | ms/batch 2934.45 | loss 1.63556


  0%|          | 56/22759 [02:48<18:27:51,  2.93s/it]

| epoch   1 |    55/22758 batches | lr 0.0000200 | ms/batch 2923.50 | loss 1.49992


  0%|          | 57/22759 [02:51<18:27:38,  2.93s/it]

| epoch   1 |    56/22758 batches | lr 0.0000200 | ms/batch 2926.28 | loss 1.60646


  0%|          | 58/22759 [02:54<18:26:52,  2.93s/it]

| epoch   1 |    57/22758 batches | lr 0.0000200 | ms/batch 2921.03 | loss 1.61651


  0%|          | 59/22759 [02:57<18:33:25,  2.94s/it]

| epoch   1 |    58/22758 batches | lr 0.0000200 | ms/batch 2983.54 | loss 1.53537


  0%|          | 60/22759 [03:00<18:31:39,  2.94s/it]

| epoch   1 |    59/22758 batches | lr 0.0000200 | ms/batch 2927.74 | loss 1.66742


  0%|          | 61/22759 [03:02<18:29:15,  2.93s/it]

| epoch   1 |    60/22758 batches | lr 0.0000200 | ms/batch 2917.45 | loss 1.59098


  0%|          | 62/22759 [03:05<18:27:25,  2.93s/it]

| epoch   1 |    61/22758 batches | lr 0.0000200 | ms/batch 2916.40 | loss 1.58944


  0%|          | 63/22759 [03:08<18:26:26,  2.93s/it]

| epoch   1 |    62/22758 batches | lr 0.0000200 | ms/batch 2919.14 | loss 1.39825


  0%|          | 64/22759 [03:11<18:25:18,  2.92s/it]

| epoch   1 |    63/22758 batches | lr 0.0000200 | ms/batch 2915.38 | loss 1.66447


  0%|          | 65/22759 [03:14<18:24:39,  2.92s/it]

| epoch   1 |    64/22758 batches | lr 0.0000200 | ms/batch 2916.77 | loss 1.64497


  0%|          | 66/22759 [03:17<18:29:04,  2.93s/it]

| epoch   1 |    65/22758 batches | lr 0.0000200 | ms/batch 2959.88 | loss 1.50000


  0%|          | 67/22759 [03:20<18:27:35,  2.93s/it]

| epoch   1 |    66/22758 batches | lr 0.0000200 | ms/batch 2919.66 | loss 1.61260


  0%|          | 68/22759 [03:23<18:26:11,  2.93s/it]

| epoch   1 |    67/22758 batches | lr 0.0000200 | ms/batch 2916.63 | loss 1.61354


  0%|          | 69/22759 [03:26<18:25:11,  2.92s/it]

| epoch   1 |    68/22758 batches | lr 0.0000200 | ms/batch 2916.53 | loss 1.41785


  0%|          | 70/22759 [03:29<18:24:53,  2.92s/it]

| epoch   1 |    69/22758 batches | lr 0.0000200 | ms/batch 2920.23 | loss 1.58308


  0%|          | 71/22759 [03:32<18:24:28,  2.92s/it]

| epoch   1 |    70/22758 batches | lr 0.0000200 | ms/batch 2918.46 | loss 1.59955


  0%|          | 72/22759 [03:35<18:26:34,  2.93s/it]

| epoch   1 |    71/22758 batches | lr 0.0000200 | ms/batch 2939.68 | loss 1.49589


  0%|          | 73/22759 [03:38<18:25:55,  2.92s/it]

| epoch   1 |    72/22758 batches | lr 0.0000200 | ms/batch 2921.06 | loss 1.55118


  0%|          | 74/22759 [03:40<18:25:02,  2.92s/it]

| epoch   1 |    73/22758 batches | lr 0.0000200 | ms/batch 2917.44 | loss 1.48561


  0%|          | 75/22759 [03:43<18:24:21,  2.92s/it]

| epoch   1 |    74/22758 batches | lr 0.0000200 | ms/batch 2916.97 | loss 1.37814


  0%|          | 76/22759 [03:46<18:24:42,  2.92s/it]

| epoch   1 |    75/22758 batches | lr 0.0000200 | ms/batch 2924.62 | loss 1.67272


  0%|          | 77/22759 [03:49<18:24:43,  2.92s/it]

| epoch   1 |    76/22758 batches | lr 0.0000200 | ms/batch 2922.58 | loss 1.76804


  0%|          | 78/22759 [03:52<18:23:46,  2.92s/it]

| epoch   1 |    77/22758 batches | lr 0.0000200 | ms/batch 2914.26 | loss 1.64887


  0%|          | 79/22759 [03:55<18:23:28,  2.92s/it]

| epoch   1 |    78/22758 batches | lr 0.0000200 | ms/batch 2917.62 | loss 1.51614


  0%|          | 80/22759 [03:58<18:28:43,  2.93s/it]

| epoch   1 |    79/22758 batches | lr 0.0000200 | ms/batch 2965.81 | loss 1.67785


  0%|          | 81/22759 [04:01<18:27:13,  2.93s/it]

| epoch   1 |    80/22758 batches | lr 0.0000200 | ms/batch 2920.36 | loss 1.28234


  0%|          | 82/22759 [04:04<18:25:24,  2.92s/it]

| epoch   1 |    81/22758 batches | lr 0.0000200 | ms/batch 2913.81 | loss 1.28396


  0%|          | 83/22759 [04:07<18:24:23,  2.92s/it]

| epoch   1 |    82/22758 batches | lr 0.0000200 | ms/batch 2916.11 | loss 1.62336


  0%|          | 84/22759 [04:10<18:23:45,  2.92s/it]

| epoch   1 |    83/22758 batches | lr 0.0000200 | ms/batch 2916.90 | loss 1.69428


  0%|          | 85/22759 [04:13<18:23:03,  2.92s/it]

| epoch   1 |    84/22758 batches | lr 0.0000200 | ms/batch 2914.80 | loss 1.60492


  0%|          | 86/22759 [04:16<18:22:24,  2.92s/it]

| epoch   1 |    85/22758 batches | lr 0.0000200 | ms/batch 2913.51 | loss 1.52354


  0%|          | 87/22759 [04:18<18:22:40,  2.92s/it]

| epoch   1 |    86/22758 batches | lr 0.0000200 | ms/batch 2920.08 | loss 1.54339


  0%|          | 88/22759 [04:21<18:22:56,  2.92s/it]

| epoch   1 |    87/22758 batches | lr 0.0000200 | ms/batch 2920.68 | loss 1.67685


  0%|          | 89/22759 [04:24<18:22:52,  2.92s/it]

| epoch   1 |    88/22758 batches | lr 0.0000200 | ms/batch 2918.80 | loss 1.57377


  0%|          | 90/22759 [04:27<18:22:23,  2.92s/it]

| epoch   1 |    89/22758 batches | lr 0.0000200 | ms/batch 2914.96 | loss 1.53873


  0%|          | 91/22759 [04:30<18:22:45,  2.92s/it]

| epoch   1 |    90/22758 batches | lr 0.0000200 | ms/batch 2921.40 | loss 1.75393


  0%|          | 92/22759 [04:33<18:22:32,  2.92s/it]

| epoch   1 |    91/22758 batches | lr 0.0000200 | ms/batch 2917.38 | loss 1.66978


  0%|          | 93/22759 [04:36<18:22:22,  2.92s/it]

| epoch   1 |    92/22758 batches | lr 0.0000200 | ms/batch 2917.27 | loss 1.55546


  0%|          | 94/22759 [04:39<18:24:58,  2.93s/it]

| epoch   1 |    93/22758 batches | lr 0.0000200 | ms/batch 2941.38 | loss 1.80391


  0%|          | 95/22759 [04:42<18:24:41,  2.92s/it]

| epoch   1 |    94/22758 batches | lr 0.0000200 | ms/batch 2922.94 | loss 1.51705


  0%|          | 96/22759 [04:45<18:24:25,  2.92s/it]

| epoch   1 |    95/22758 batches | lr 0.0000200 | ms/batch 2922.56 | loss 1.59760


  0%|          | 97/22759 [04:48<18:23:42,  2.92s/it]

| epoch   1 |    96/22758 batches | lr 0.0000200 | ms/batch 2918.00 | loss 1.43044


  0%|          | 98/22759 [04:51<18:23:21,  2.92s/it]

| epoch   1 |    97/22758 batches | lr 0.0000200 | ms/batch 2919.44 | loss 1.54504


  0%|          | 99/22759 [04:54<18:22:19,  2.92s/it]

| epoch   1 |    98/22758 batches | lr 0.0000200 | ms/batch 2912.53 | loss 1.54301


  0%|          | 100/22759 [04:56<18:20:25,  2.91s/it]

| epoch   1 |    99/22758 batches | lr 0.0000200 | ms/batch 2902.35 | loss 1.58384


  0%|          | 100/22759 [04:59<18:52:31,  3.00s/it]

| epoch   1 |   100/22758 batches | lr 0.0000200 | ms/batch 2971.93 | loss 1.54353



100%|██████████| 1511/1511 [24:05<00:00,  1.05it/s]
2023-07-17 11:50:37,143 - INFO - 
 Scores:
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
2023-07-17 11:50:37,205 - INFO - 
               precision    recall  f1-score   support

           0       0.00      0.00      0.00      1158
           1       0.50      0.02      0.03     16938
           2       0.43      0.85      0.57     14848
           3       0.00      0.00      0.00      1713
           4       0.04      0.15      0.06      1600

    accuracy                           0.36     36257
   macro avg       0.19      0.20      0.13     36257
weighted avg       0.41      0.36      0.25     36257

2023-07-17 11:50:37,206 - INFO - -----------------------------------------------------------------------------------------
2023-07-17 11:50:37,207 - INFO - | end of epoch   1 | time: 1442.11s | valid loss 80.29001
2

| epoch   2 |     1/22758 batches | lr 0.0000190 | ms/batch 5853.72 | loss 3.25321
| epoch   2 |     2/22758 batches | lr 0.0000190 | ms/batch 2944.87 | loss 1.68745
| epoch   2 |     3/22758 batches | lr 0.0000190 | ms/batch 2970.22 | loss 1.76145
| epoch   2 |     4/22758 batches | lr 0.0000190 | ms/batch 2919.10 | loss 1.71872
| epoch   2 |     5/22758 batches | lr 0.0000190 | ms/batch 2920.87 | loss 1.59743
| epoch   2 |     6/22758 batches | lr 0.0000190 | ms/batch 2924.06 | loss 1.55215
| epoch   2 |     7/22758 batches | lr 0.0000190 | ms/batch 2923.04 | loss 1.66996
| epoch   2 |     8/22758 batches | lr 0.0000190 | ms/batch 2925.49 | loss 1.54641
| epoch   2 |     9/22758 batches | lr 0.0000190 | ms/batch 2931.90 | loss 1.55562
| epoch   2 |    10/22758 batches | lr 0.0000190 | ms/batch 2921.41 | loss 1.67077
| epoch   2 |    11/22758 batches | lr 0.0000190 | ms/batch 2924.22 | loss 1.63249
| epoch   2 |    12/22758 batches | lr 0.0000190 | ms/batch 2916.35 | loss 1.59587
| ep

2023-07-17 12:14:36,253 - INFO - 
 Scores:
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
2023-07-17 12:14:36,304 - INFO - 
               precision    recall  f1-score   support

           0       0.00      0.00      0.00      1158
           1       0.52      0.61      0.56     16938
           2       0.47      0.02      0.04     14848
           3       0.00      0.00      0.00      1713
           4       0.05      0.49      0.09      1600

    accuracy                           0.31     36257
   macro avg       0.21      0.22      0.14     36257
weighted avg       0.44      0.31      0.28     36257

2023-07-17 12:14:36,305 - INFO - -----------------------------------------------------------------------------------------
2023-07-17 12:14:36,306 - INFO - | end of epoch   2 | time: 1438.45s | valid loss 79.39922
2023-07-17 12:14:36,306 - INFO - ---------------------

| epoch   3 |     1/22758 batches | lr 0.0000181 | ms/batch 5847.81 | loss 3.15688
| epoch   3 |     2/22758 batches | lr 0.0000181 | ms/batch 2909.91 | loss 1.63082
| epoch   3 |     3/22758 batches | lr 0.0000181 | ms/batch 2912.65 | loss 1.53810
| epoch   3 |     4/22758 batches | lr 0.0000181 | ms/batch 2918.40 | loss 1.54178
| epoch   3 |     5/22758 batches | lr 0.0000181 | ms/batch 2910.59 | loss 1.64238
| epoch   3 |     6/22758 batches | lr 0.0000181 | ms/batch 2926.31 | loss 1.50475
| epoch   3 |     7/22758 batches | lr 0.0000181 | ms/batch 2921.47 | loss 1.52614
| epoch   3 |     8/22758 batches | lr 0.0000181 | ms/batch 2918.80 | loss 1.57445
| epoch   3 |     9/22758 batches | lr 0.0000181 | ms/batch 2918.49 | loss 1.55922
| epoch   3 |    10/22758 batches | lr 0.0000181 | ms/batch 2917.41 | loss 1.50498
| epoch   3 |    11/22758 batches | lr 0.0000181 | ms/batch 2925.17 | loss 1.38305
| epoch   3 |    12/22758 batches | lr 0.0000181 | ms/batch 2915.53 | loss 1.47678
| ep

2023-07-17 12:39:04,186 - INFO - 
 Scores:
2023-07-17 12:39:04,237 - INFO - 
               precision    recall  f1-score   support

           0       0.14      0.40      0.21      1158
           1       0.66      0.44      0.53     16938
           2       0.70      0.06      0.12     14848
           3       0.22      0.04      0.07      1713
           4       0.04      0.54      0.08      1600

    accuracy                           0.27     36257
   macro avg       0.35      0.30      0.20     36257
weighted avg       0.61      0.27      0.31     36257

2023-07-17 12:39:04,238 - INFO - -----------------------------------------------------------------------------------------
2023-07-17 12:39:04,239 - INFO - | end of epoch   3 | time: 1441.80s | valid loss 74.36805
2023-07-17 12:39:04,240 - INFO - -----------------------------------------------------------------------------------------


| epoch   4 |     1/22758 batches | lr 0.0000171 | ms/batch 5832.39 | loss 2.93668
| epoch   4 |     2/22758 batches | lr 0.0000171 | ms/batch 2917.87 | loss 1.31364
| epoch   4 |     3/22758 batches | lr 0.0000171 | ms/batch 2919.47 | loss 1.47680
| epoch   4 |     4/22758 batches | lr 0.0000171 | ms/batch 2915.70 | loss 1.43661
| epoch   4 |     5/22758 batches | lr 0.0000171 | ms/batch 2929.09 | loss 1.74241
| epoch   4 |     6/22758 batches | lr 0.0000171 | ms/batch 2915.97 | loss 1.20938
| epoch   4 |     7/22758 batches | lr 0.0000171 | ms/batch 2920.95 | loss 1.22711
| epoch   4 |     8/22758 batches | lr 0.0000171 | ms/batch 2918.13 | loss 1.40864
| epoch   4 |     9/22758 batches | lr 0.0000171 | ms/batch 2919.30 | loss 1.91864
| epoch   4 |    10/22758 batches | lr 0.0000171 | ms/batch 2921.78 | loss 1.45985
| epoch   4 |    11/22758 batches | lr 0.0000171 | ms/batch 2920.68 | loss 1.19569
| epoch   4 |    12/22758 batches | lr 0.0000171 | ms/batch 2949.60 | loss 1.33313
| ep

2023-07-17 13:03:11,414 - INFO - 
 Scores:
2023-07-17 13:03:11,468 - INFO - 
               precision    recall  f1-score   support

           0       0.08      0.81      0.15      1158
           1       0.51      0.43      0.47     16938
           2       0.71      0.45      0.55     14848
           3       0.39      0.10      0.16      1713
           4       0.09      0.04      0.06      1600

    accuracy                           0.42     36257
   macro avg       0.36      0.37      0.28     36257
weighted avg       0.55      0.42      0.46     36257

2023-07-17 13:03:11,469 - INFO - -----------------------------------------------------------------------------------------
2023-07-17 13:03:11,470 - INFO - | end of epoch   4 | time: 1435.76s | valid loss 73.16209
2023-07-17 13:03:11,471 - INFO - -----------------------------------------------------------------------------------------


| epoch   5 |     1/22758 batches | lr 0.0000163 | ms/batch 5862.14 | loss 2.99125
| epoch   5 |     2/22758 batches | lr 0.0000163 | ms/batch 2920.59 | loss 1.31630
| epoch   5 |     3/22758 batches | lr 0.0000163 | ms/batch 2941.11 | loss 1.45659
| epoch   5 |     4/22758 batches | lr 0.0000163 | ms/batch 2923.52 | loss 1.33540
| epoch   5 |     5/22758 batches | lr 0.0000163 | ms/batch 2919.42 | loss 1.64563
| epoch   5 |     6/22758 batches | lr 0.0000163 | ms/batch 2918.11 | loss 1.48746
| epoch   5 |     7/22758 batches | lr 0.0000163 | ms/batch 2910.01 | loss 1.87122
| epoch   5 |     8/22758 batches | lr 0.0000163 | ms/batch 2907.34 | loss 1.30860
| epoch   5 |     9/22758 batches | lr 0.0000163 | ms/batch 2911.46 | loss 1.25273
| epoch   5 |    10/22758 batches | lr 0.0000163 | ms/batch 2946.29 | loss 1.36217
| epoch   5 |    11/22758 batches | lr 0.0000163 | ms/batch 2923.79 | loss 1.50565
| epoch   5 |    12/22758 batches | lr 0.0000163 | ms/batch 2908.73 | loss 2.03124
| ep

2023-07-17 13:27:18,116 - INFO - 
 Scores:
2023-07-17 13:27:18,178 - INFO - 
               precision    recall  f1-score   support

           0       0.14      0.44      0.21      1158
           1       0.68      0.51      0.58     16938
           2       0.62      0.56      0.59     14848
           3       0.12      0.34      0.18      1713
           4       0.04      0.04      0.04      1600

    accuracy                           0.50     36257
   macro avg       0.32      0.38      0.32     36257
weighted avg       0.59      0.50      0.53     36257

2023-07-17 13:27:18,179 - INFO - -----------------------------------------------------------------------------------------
2023-07-17 13:27:18,180 - INFO - | end of epoch   5 | time: 1439.19s | valid loss 72.45588
2023-07-17 13:27:18,181 - INFO - -----------------------------------------------------------------------------------------


| epoch   6 |     1/22758 batches | lr 0.0000155 | ms/batch 5856.61 | loss 3.16991
| epoch   6 |     2/22758 batches | lr 0.0000155 | ms/batch 2920.38 | loss 1.28221
| epoch   6 |     3/22758 batches | lr 0.0000155 | ms/batch 2921.18 | loss 1.10255
| epoch   6 |     4/22758 batches | lr 0.0000155 | ms/batch 2917.20 | loss 1.85213
| epoch   6 |     5/22758 batches | lr 0.0000155 | ms/batch 2922.84 | loss 1.35475
| epoch   6 |     6/22758 batches | lr 0.0000155 | ms/batch 2990.14 | loss 1.16973
| epoch   6 |     7/22758 batches | lr 0.0000155 | ms/batch 2916.05 | loss 1.42471
| epoch   6 |     8/22758 batches | lr 0.0000155 | ms/batch 2924.55 | loss 1.54204
| epoch   6 |     9/22758 batches | lr 0.0000155 | ms/batch 2924.57 | loss 1.36597
| epoch   6 |    10/22758 batches | lr 0.0000155 | ms/batch 2922.74 | loss 1.50469
| epoch   6 |    11/22758 batches | lr 0.0000155 | ms/batch 2929.85 | loss 1.52680
| epoch   6 |    12/22758 batches | lr 0.0000155 | ms/batch 2920.41 | loss 1.52535
| ep

2023-07-17 13:51:30,961 - INFO - 
 Scores:
2023-07-17 13:51:31,017 - INFO - 
               precision    recall  f1-score   support

           0       0.13      0.50      0.21      1158
           1       0.72      0.42      0.53     16938
           2       0.75      0.48      0.59     14848
           3       0.30      0.23      0.26      1713
           4       0.04      0.30      0.08      1600

    accuracy                           0.44     36257
   macro avg       0.39      0.39      0.33     36257
weighted avg       0.66      0.44      0.51     36257

2023-07-17 13:51:31,019 - INFO - -----------------------------------------------------------------------------------------
2023-07-17 13:51:31,020 - INFO - | end of epoch   6 | time: 1441.78s | valid loss 70.27165
2023-07-17 13:51:31,021 - INFO - -----------------------------------------------------------------------------------------


| epoch   7 |     1/22758 batches | lr 0.0000147 | ms/batch 5852.09 | loss 2.31261
| epoch   7 |     2/22758 batches | lr 0.0000147 | ms/batch 2920.14 | loss 0.85204
| epoch   7 |     3/22758 batches | lr 0.0000147 | ms/batch 2922.74 | loss 1.60093
| epoch   7 |     4/22758 batches | lr 0.0000147 | ms/batch 2983.99 | loss 1.31920
| epoch   7 |     5/22758 batches | lr 0.0000147 | ms/batch 2920.26 | loss 1.68314
| epoch   7 |     6/22758 batches | lr 0.0000147 | ms/batch 2923.26 | loss 1.05793
| epoch   7 |     7/22758 batches | lr 0.0000147 | ms/batch 2925.82 | loss 1.36122
| epoch   7 |     8/22758 batches | lr 0.0000147 | ms/batch 2917.36 | loss 0.85137
| epoch   7 |     9/22758 batches | lr 0.0000147 | ms/batch 2922.16 | loss 1.18954
| epoch   7 |    10/22758 batches | lr 0.0000147 | ms/batch 2919.98 | loss 1.64406
| epoch   7 |    11/22758 batches | lr 0.0000147 | ms/batch 2955.63 | loss 0.69318
| epoch   7 |    12/22758 batches | lr 0.0000147 | ms/batch 2929.41 | loss 1.64020
| ep

2023-07-17 14:15:35,252 - INFO - 
 Scores:
2023-07-17 14:15:35,303 - INFO - 
               precision    recall  f1-score   support

           0       0.15      0.47      0.22      1158
           1       0.63      0.71      0.66     16938
           2       0.72      0.57      0.64     14848
           3       0.33      0.22      0.27      1713
           4       0.09      0.03      0.04      1600

    accuracy                           0.59     36257
   macro avg       0.38      0.40      0.37     36257
weighted avg       0.61      0.59      0.59     36257

2023-07-17 14:15:35,305 - INFO - -----------------------------------------------------------------------------------------
2023-07-17 14:15:35,306 - INFO - | end of epoch   7 | time: 1436.54s | valid loss 69.13249
2023-07-17 14:15:35,306 - INFO - -----------------------------------------------------------------------------------------


| epoch   8 |     1/22758 batches | lr 0.0000140 | ms/batch 5863.47 | loss 2.78418
| epoch   8 |     2/22758 batches | lr 0.0000140 | ms/batch 2919.72 | loss 1.33664
| epoch   8 |     3/22758 batches | lr 0.0000140 | ms/batch 2905.22 | loss 1.28363
| epoch   8 |     4/22758 batches | lr 0.0000140 | ms/batch 2961.36 | loss 1.28963
| epoch   8 |     5/22758 batches | lr 0.0000140 | ms/batch 2920.38 | loss 1.24685
| epoch   8 |     6/22758 batches | lr 0.0000140 | ms/batch 2919.79 | loss 1.24316
| epoch   8 |     7/22758 batches | lr 0.0000140 | ms/batch 2920.87 | loss 1.44340
| epoch   8 |     8/22758 batches | lr 0.0000140 | ms/batch 2922.06 | loss 1.46030
| epoch   8 |     9/22758 batches | lr 0.0000140 | ms/batch 2916.56 | loss 1.06249
| epoch   8 |    10/22758 batches | lr 0.0000140 | ms/batch 2914.32 | loss 1.30905
| epoch   8 |    11/22758 batches | lr 0.0000140 | ms/batch 2961.82 | loss 1.49124
| epoch   8 |    12/22758 batches | lr 0.0000140 | ms/batch 2922.53 | loss 1.08213
| ep

2023-07-17 14:39:46,372 - INFO - 
 Scores:
2023-07-17 14:39:46,423 - INFO - 
               precision    recall  f1-score   support

           0       0.13      0.62      0.22      1158
           1       0.51      0.74      0.60     16938
           2       0.78      0.25      0.38     14848
           3       0.33      0.18      0.23      1713
           4       0.07      0.03      0.04      1600

    accuracy                           0.48     36257
   macro avg       0.36      0.36      0.30     36257
weighted avg       0.58      0.48      0.46     36257

2023-07-17 14:39:46,424 - INFO - -----------------------------------------------------------------------------------------
2023-07-17 14:39:46,425 - INFO - | end of epoch   8 | time: 1440.15s | valid loss 72.12461
2023-07-17 14:39:46,426 - INFO - -----------------------------------------------------------------------------------------


| epoch   9 |     1/22758 batches | lr 0.0000133 | ms/batch 5847.78 | loss 2.69972
| epoch   9 |     2/22758 batches | lr 0.0000133 | ms/batch 2924.67 | loss 1.27642
| epoch   9 |     3/22758 batches | lr 0.0000133 | ms/batch 2919.94 | loss 1.49561
| epoch   9 |     4/22758 batches | lr 0.0000133 | ms/batch 2921.12 | loss 1.06143
| epoch   9 |     5/22758 batches | lr 0.0000133 | ms/batch 2996.94 | loss 1.36295
| epoch   9 |     6/22758 batches | lr 0.0000133 | ms/batch 2929.97 | loss 1.54357
| epoch   9 |     7/22758 batches | lr 0.0000133 | ms/batch 2927.00 | loss 1.31199
| epoch   9 |     8/22758 batches | lr 0.0000133 | ms/batch 2921.94 | loss 1.06850
| epoch   9 |     9/22758 batches | lr 0.0000133 | ms/batch 2926.84 | loss 1.66200
| epoch   9 |    10/22758 batches | lr 0.0000133 | ms/batch 2926.14 | loss 1.23944
| epoch   9 |    11/22758 batches | lr 0.0000133 | ms/batch 2927.37 | loss 1.54754
| epoch   9 |    12/22758 batches | lr 0.0000133 | ms/batch 2940.66 | loss 1.44903
| ep

2023-07-17 15:03:46,722 - INFO - 
 Scores:
2023-07-17 15:03:46,788 - INFO - 
               precision    recall  f1-score   support

           0       0.14      0.63      0.23      1158
           1       0.72      0.47      0.57     16938
           2       0.64      0.67      0.65     14848
           3       0.24      0.32      0.28      1713
           4       0.05      0.06      0.05      1600

    accuracy                           0.53     36257
   macro avg       0.36      0.43      0.36     36257
weighted avg       0.61      0.53      0.56     36257

2023-07-17 15:03:46,790 - INFO - -----------------------------------------------------------------------------------------
2023-07-17 15:03:46,790 - INFO - | end of epoch   9 | time: 1440.36s | valid loss 67.39701
2023-07-17 15:03:46,791 - INFO - -----------------------------------------------------------------------------------------


| epoch  10 |     1/22758 batches | lr 0.0000126 | ms/batch 5934.34 | loss 3.28735
| epoch  10 |     2/22758 batches | lr 0.0000126 | ms/batch 2911.77 | loss 1.33608
| epoch  10 |     3/22758 batches | lr 0.0000126 | ms/batch 2909.76 | loss 1.57771
| epoch  10 |     4/22758 batches | lr 0.0000126 | ms/batch 2909.54 | loss 1.13079
| epoch  10 |     5/22758 batches | lr 0.0000126 | ms/batch 2921.00 | loss 1.05130
| epoch  10 |     6/22758 batches | lr 0.0000126 | ms/batch 2926.40 | loss 1.14101
| epoch  10 |     7/22758 batches | lr 0.0000126 | ms/batch 2917.56 | loss 1.42674
| epoch  10 |     8/22758 batches | lr 0.0000126 | ms/batch 2922.09 | loss 1.44864
| epoch  10 |     9/22758 batches | lr 0.0000126 | ms/batch 2905.75 | loss 1.31752
| epoch  10 |    10/22758 batches | lr 0.0000126 | ms/batch 2905.81 | loss 1.21934
| epoch  10 |    11/22758 batches | lr 0.0000126 | ms/batch 2905.83 | loss 1.03424
| epoch  10 |    12/22758 batches | lr 0.0000126 | ms/batch 2926.93 | loss 1.42170
| ep

In [None]:
(1/43)/((1/30) + (1/377) + (1/327) + (1/43) + (1/31))