In [1]:
import os
import torch
import torch.utils.data
import torchvision
from PIL import Image
from pycocotools.coco import COCO
import cv2
import matplotlib.pyplot as plt
from sklearn.metrics import f1_score, precision_score, recall_score, cohen_kappa_score, classification_report, confusion_matrix
from scipy.stats import rankdata
from numpy import linalg as LA
# import intel_extension_for_pytorch as ipex

%load_ext autotime

# import albumentations as A
# from albumentations.pytorch import ToTensorV2
import numpy as np
import torch.nn.functional as F

# from torchmetrics.detection.mean_ap import MeanAveragePrecision
from torch.optim.lr_scheduler import MultiStepLR
import time
import os
plt.style.use('ggplot')
# Fall back to CPU so the notebook still runs on machines without a CUDA GPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
import os
import time
import math
from tqdm import tqdm
import logging
from torch.autograd import Variable
from torch.utils.tensorboard import SummaryWriter
from torch.optim.lr_scheduler import StepLR, MultiStepLR

# Seed every RNG the notebook may draw from so reruns are reproducible.
# torch.manual_seed already seeds all CUDA devices; manual_seed_all is kept to be explicit.
import random

seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)

time: 172 ms (started: 2024-05-06 06:46:56 +06:00)


In [2]:
##################################################################
#   Utils
##################################################################


"""
Tensorboard logger code referenced from:
https://github.com/yunjey/pytorch-tutorial/blob/master/tutorials/04-utils/
Other helper functions:
https://github.com/cs230-stanford/cs230-stanford.github.io
"""

import json
import logging
import os
import shutil
import torch
from collections import OrderedDict
from torch.optim.lr_scheduler import _LRScheduler
# import tensorflow as tf
import numpy as np
import scipy.misc 
try:
    from StringIO import StringIO  # Python 2.7
except ImportError:
    from io import BytesIO         # Python 3.x


class Params():
    """Hyperparameters read from a JSON file and exposed as instance attributes.

    Every top-level key of the JSON object becomes an attribute.

    Example:
    ```
    params = Params(json_path)
    print(params.learning_rate)
    params.learning_rate = 0.5  # change the value of learning_rate in params
    ```
    """

    def __init__(self, json_path):
        self.update(json_path)

    def save(self, json_path):
        """Write every current attribute to `json_path` as indented JSON."""
        with open(json_path, 'w') as out_file:
            json.dump(self.__dict__, out_file, indent=4)

    def update(self, json_path):
        """Merge the key/value pairs of the JSON file at `json_path` into this instance."""
        with open(json_path) as in_file:
            self.__dict__.update(json.load(in_file))

    @property
    def dict(self):
        """Dict-like access to the parameters, e.g. `params.dict['learning_rate']`."""
        return self.__dict__


class RunningAverage():
    """Unweighted running mean of a quantity: every `update` call counts once.

    Example:
    ```
    loss_avg = RunningAverage()
    loss_avg.update(2)
    loss_avg.update(4)
    loss_avg()  # -> 3.0
    ```
    """
    def __init__(self):
        self.steps = 0
        self.total = 0

    def update(self, val):
        """Add one observation `val` to the running total."""
        self.total += val
        self.steps += 1

    def __call__(self):
        """Return the mean of all observations so far.

        Returns 0.0 before the first `update` instead of raising ZeroDivisionError.
        """
        if self.steps == 0:
            return 0.0
        return self.total/float(self.steps)

class AverageMeter(object):
    """Track the latest value and the count-weighted running mean of a metric.

    Attributes (all zeroed by `reset`):
        val: the value passed to the most recent `update`.
        sum: running sum of `val * n` over all updates.
        count: running sum of the weights `n`.
        avg: `sum / count`.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero every statistic."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record `val` with weight `n` (e.g. a batch-mean loss over n samples)."""
        self.val = val
        weighted = val * n
        self.sum += weighted
        self.count += n
        self.avg = self.sum / self.count

def get_lr(optimizer):
    """Return the learning rate of the optimizer's first param group, or None if it has none."""
    return next((group['lr'] for group in optimizer.param_groups), None)

def set_logger(log_path):
    """Route INFO-level logging both to the terminal and to the file `log_path`.

    File records are formatted as 'time:LEVEL: message'; console records carry only
    the message. Handlers are attached to the root logger only when it has none yet,
    so calling this again (e.g. re-running the cell) does not duplicate output.

    Example:
    ```
    set_logger('train.log')
    logging.info("Starting training...")
    ```

    Args:
        log_path: (string) path of the log file to append to
    """
    root_logger = logging.getLogger()
    root_logger.setLevel(logging.INFO)
    if root_logger.handlers:
        return

    handler_specs = (
        (logging.FileHandler(log_path), '%(asctime)s:%(levelname)s: %(message)s'),
        (logging.StreamHandler(), '%(message)s'),
    )
    for handler, fmt in handler_specs:
        handler.setFormatter(logging.Formatter(fmt))
        root_logger.addHandler(handler)


def save_dict_to_json(d, json_path):
    """Save a dict of float-castable values as indented JSON.

    Args:
        d: (dict) of float-castable values (np.float, int, float, 0-dim tensors, ...)
        json_path: (string) path to the json file; overwritten if it exists

    Raises:
        TypeError / ValueError: if a value cannot be converted with float(); in that case
            an existing file at `json_path` is left untouched.
    """
    # Convert before opening: opening with 'w' truncates the file, so a failed
    # conversion inside the `with` used to leave an empty/corrupt JSON behind.
    # (json cannot serialize np.array / np.float, hence the float() cast.)
    float_dict = {k: float(v) for k, v in d.items()}
    with open(json_path, 'w') as f:
        json.dump(float_dict, f, indent=4)


def save_checkpoint(state, is_best, checkpoint, epoch_checkpoint = False):
    """Save `state` to checkpoint/'last.pth.tar' and optionally copy it.

    Args:
        state: (dict) contains model's state_dict, may contain other keys such as epoch,
            optimizer state_dict. Must contain 'epoch' when `epoch_checkpoint` is True.
        is_best: (bool) True if it is the best model seen till now; also copies the file
            to checkpoint/'best.pth.tar'
        checkpoint: (string) folder where parameters are to be saved; created (including
            missing parent folders) if it does not exist
        epoch_checkpoint: (bool) if True, also copies the file to
            checkpoint/'<state["epoch"] - 1>.pth.tar'
    """
    filepath = os.path.join(checkpoint, 'last.pth.tar')
    if not os.path.exists(checkpoint):
        print("Checkpoint Directory does not exist! Making directory {}".format(checkpoint))
        # makedirs (not mkdir) so nested paths such as 'runs/exp1' work.
        os.makedirs(checkpoint)
    else:
        print("Checkpoint Directory exists! ")
    torch.save(state, filepath)
    if is_best:
        shutil.copyfile(filepath, os.path.join(checkpoint, 'best.pth.tar'))
    if epoch_checkpoint:
        # state['epoch'] is stored 1-based by the caller; the file name is the 0-based index.
        epoch_file = str(state['epoch'] - 1) + '.pth.tar'
        shutil.copyfile(filepath, os.path.join(checkpoint, epoch_file))





def load_checkpoint(checkpoint, model, optimizer=None):
    """Load model (and optionally optimizer) state from a checkpoint file.

    Args:
        checkpoint: (string) filename which needs to be loaded; expected to hold a dict
            with 'state_dict' (and 'optim_dict' when `optimizer` is given)
        model: (torch.nn.Module) model for which the parameters are loaded
        optimizer: (torch.optim) optional: resume optimizer from checkpoint

    Returns:
        The full loaded checkpoint dict (so callers can read e.g. 'epoch').

    Raises:
        FileNotFoundError: if `checkpoint` does not exist.
    """
    # Fail early with a readable message (the message used to be built and discarded).
    if not os.path.exists(checkpoint):
        raise FileNotFoundError("File doesn't exist {}".format(checkpoint))

    if torch.cuda.is_available():
        checkpoint = torch.load(checkpoint)
    else:
        # Map every storage to CPU so GPU-trained weights load on a CPU-only machine.
        checkpoint = torch.load(checkpoint, map_location='cpu')

    model.load_state_dict(checkpoint['state_dict'])

    if optimizer is not None:
        optimizer.load_state_dict(checkpoint['optim_dict'])

    return checkpoint


class WarmUpLR(_LRScheduler):
    """Linear learning-rate warm-up scheduler.

    After k calls to `step()`, every param group's LR is `base_lr * k / total_iters`.
    The 1e-8 in the denominator only guards against total_iters == 0. The LR is not
    clamped, so stepping past `total_iters` keeps increasing it.

    Args:
        optimizer: wrapped optimizer (e.g. SGD)
        total_iters: number of steps over which the LR ramps from 0 to base_lr
        last_epoch: index of the last step taken; -1 starts from scratch
    """

    def __init__(self, optimizer, total_iters, last_epoch=-1):
        # Must be set before the base constructor, which already calls get_lr().
        self.total_iters = total_iters
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        """Return the warmed-up LR for every param group at step `self.last_epoch`."""
        step = self.last_epoch
        denom = self.total_iters + 1e-8
        return list(map(lambda base_lr: base_lr * step / denom, self.base_lrs))

time: 0 ns (started: 2024-05-06 06:46:56 +06:00)


In [3]:
# Training pipeline: resize to a fixed 64x64, random flips and photometric
# augmentation, then convert to a tensor. Test pipeline: resize + tensor only.
# NOTE(review): torchvision's ColorJitter() with default (all-zero) arguments leaves
# pixel values unchanged; pass brightness/contrast/... to make it effective.
# NOTE(review): no Normalize step is applied although the model built later uses
# ImageNet-pretrained weights — confirm this is intended.
transform = dict(
    train=torchvision.transforms.Compose([
        torchvision.transforms.Resize([64, 64]),        # fixed 64x64 input size
        torchvision.transforms.RandomHorizontalFlip(),  # random left-right flip
        torchvision.transforms.RandomVerticalFlip(),    # random top-bottom flip
        torchvision.transforms.ColorJitter(),
        torchvision.transforms.RandAugment(),
        torchvision.transforms.ToTensor(),
    ]),
    test=torchvision.transforms.Compose([
        torchvision.transforms.Resize([64, 64]),
        torchvision.transforms.ToTensor(),
    ]),
)

time: 32 ms (started: 2024-05-06 06:46:56 +06:00)


In [4]:
# Image roots laid out for torchvision ImageFolder: <root>/<class_folder>/<image>.
# os.path.join gives '.\\train_images' on Windows (as before) and a valid
# './train_images' on Linux/macOS, where a literal backslash is part of the file name.
train_data_dir = os.path.join('.', 'train_images')
# train_coco = 'images_thermal_train/coco.json'
val_train_dir = os.path.join('.', 'val_images')  # validation image root
# val_coco = 'images_thermal_val/coco.json'
batch_size = 128

train_data = torchvision.datasets.ImageFolder(root=train_data_dir, transform=transform['train'])
train_dataloader = torch.utils.data.DataLoader(dataset=train_data, batch_size=batch_size, shuffle=True)

val_data = torchvision.datasets.ImageFolder(root=val_train_dir, transform=transform['test'])
# Evaluation metrics do not depend on sample order, so keep validation deterministic.
val_dataloader = torch.utils.data.DataLoader(dataset=val_data, batch_size=64, shuffle=False)

# ImageFolder numbers classes by sorted folder name per root, so a folder missing from
# one split silently shifts label indices between train and val. Warn if that happens.
_mismatched_classes = sorted(
    name for name, idx in val_data.class_to_idx.items()
    if train_data.class_to_idx.get(name) != idx
)
if _mismatched_classes:
    print("WARNING: validation label indices differ from training for classes:", _mismatched_classes)

time: 1.03 s (started: 2024-05-06 06:46:56 +06:00)


In [5]:
# Category id -> readable name (16 entries).
# NOTE(review): presumably these are the source dataset's category ids. torchvision's
# ImageFolder assigns contiguous indices 0..N-1 in sorted folder-name order instead,
# so model output indices are not these keys — confirm how folders map to ids.
class_dictionary = {
    1: 'person',
    2: 'bike',
    3: 'car',
    4: 'motor',
    6: 'bus',
    7: 'train',
    8: 'truck',
    10: 'light',
    11: 'hydrant',
    12: 'sign',
    17: 'dog',
    37: 'skateboard',
    73: 'stroller',
    77: 'scooter',
    79: 'other vehicle',
    80: 'negative',
}

time: 0 ns (started: 2024-05-06 06:46:57 +06:00)


In [6]:
# i =1
# s = list(class_dictionary)
# img= train_dataset[i][0]
# bbox = train_dataset[i][1]['boxes']
# classes = train_dataset[i][1]['labels']
# # print(len(classes))
# k = 0
# for i in val_dataset:
#     img= i[0]
#     bbox = i[1]['boxes']
#     classes = i[1]['labels']
#     for a,b in zip(bbox, classes):
# #         print(int(b))
#         xmin, ymin, xmax, ymax = a
#         pt1 = (int(xmin), int(ymin))
#         pt2 = (int(xmax), int(ymax))
#         crop = img.permute(1, 2, 0).numpy()
#         crop = crop[int(ymin):int(ymax), int(xmin):int(xmax)]
#         # vool = b not in s
#         # if b:
#         #     b = 'other'
#         cv2.imwrite(".\\val_images\\{0}\\crop_{1}.png".format(b,a), 255*crop)
#     # crop = img[0][int(ymin):int(ymax), int(xmin):int(xmax)]
#     # plt.grid(False)
#     # plt.axis('off')
#     # plt.imshow(crop, cmap = 'gray')
#     # plt.savefig('foo'+str(k)+'.png', dpi = 20)
#     # k+=1    
#     # plt.show()

#     # bnd_img = cv2.rectangle(img.permute(1, 2, 0).numpy(),pt1, pt2,(0,0,0),1)
#     # bnd_img = cv2.putText(
#     #     bnd_img,
#     #     str(j),
#     #     (int(xmin), int(ymin) - 10),
#     #     fontFace = cv2.FONT_HERSHEY_SIMPLEX,
#     #     fontScale = 0.3,
#     #     color = (0, 255, 255),
#     #     thickness=1)
    
#     # plt.imshow(bnd_img, cmap = 'gray')

time: 16 ms (started: 2024-05-06 06:46:57 +06:00)


In [7]:
# Sanity check: one transformed training sample should be a (3, 64, 64) tensor.
sample_image, sample_label = train_data[0]
sample_image.shape

torch.Size([3, 64, 64])

time: 16 ms (started: 2024-05-06 06:46:57 +06:00)


In [8]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

class FocalLoss(nn.Module):
    """Focal loss for multi-class classification.

    Per element: loss = -alpha[target] * (1 - p_t) ** gamma * log(p_t), where p_t is
    the softmax probability of the target class. The (1 - p_t) factor is treated as
    a constant (no gradient flows through it).

    Args:
        gamma: focusing parameter; 0 gives (optionally alpha-weighted) cross-entropy.
        alpha: optional class weights. A float/int `a` becomes [a, 1 - a] (two
            classes); a list is used as per-class weights indexed by the target.
        size_average: if True return the mean over elements, otherwise the sum.
    """

    def __init__(self, gamma=0, alpha=None, size_average=True):
        super().__init__()
        self.gamma = gamma
        self.alpha = alpha
        if isinstance(alpha, (float, int)):
            self.alpha = torch.Tensor([alpha, 1 - alpha])
        if isinstance(alpha, list):
            self.alpha = torch.Tensor(alpha)
        self.size_average = size_average

    def forward(self, input, target):
        """Compute the focal loss.

        Args:
            input: logits (N, C), or (N, C, *spatial) which is flattened to (N*H*W, C).
            target: integer class indices whose total size matches the flattened rows.
        """
        if input.dim() > 2:
            input = input.view(input.size(0), input.size(1), -1)  # N,C,H,W => N,C,H*W
            input = input.transpose(1, 2)                          # N,C,H*W => N,H*W,C
            input = input.contiguous().view(-1, input.size(2))     # N,H*W,C => N*H*W,C
        target = target.view(-1, 1)

        # Explicit dim=1: classes are the columns of the 2-D logits (implicit dim is deprecated).
        logpt = F.log_softmax(input, dim=1)
        logpt = logpt.gather(1, target).view(-1)
        # Detached: the modulating factor is not differentiated.
        pt = logpt.detach().exp()

        if self.alpha is not None:
            # Move/cast the class weights to the logits' device and dtype (cached for later calls).
            if self.alpha.dtype != input.dtype or self.alpha.device != input.device:
                self.alpha = self.alpha.to(device=input.device, dtype=input.dtype)
            at = self.alpha.gather(0, target.view(-1))
            logpt = logpt * at

        loss = -1 * (1 - pt) ** self.gamma * logpt
        return loss.mean() if self.size_average else loss.sum()

time: 0 ns (started: 2024-05-06 06:46:57 +06:00)


In [9]:
def loss_kd_regularization(outputs, labels, alpha=0.95, T=20, correct_prob=0.99):
    """Loss for manually-designed regularization: Tf-KD_{reg}.

    Blends cross-entropy with a KL term towards a hand-made "virtual teacher" that puts
    `correct_prob` on the true class and spreads the rest uniformly over the other
    K - 1 classes. That teacher is softened with temperature T before the KL term.

    Args:
        outputs: logits of shape (N, K); dim 1 is the class axis.
        labels: integer class indices of shape (N,).
        alpha: weight of the KL regularizer; (1 - alpha) weights cross-entropy.
        T: temperature applied to the virtual-teacher distribution.
        correct_prob: probability u(k) given to the correct class.

    Returns:
        Scalar loss tensor.
    """
    loss_CE = F.cross_entropy(outputs, labels)
    K = outputs.size(1)

    # Built on outputs' own device/dtype, so no global `device` is needed.
    off_value = (1 - correct_prob) / (K - 1) if K > 1 else 0.0
    teacher_soft = torch.full_like(outputs, off_value)  # p^d(k)
    # Vectorized replacement for a per-sample Python loop.
    teacher_soft.scatter_(1, labels.view(-1, 1).long(), correct_prob)

    # Default KLDivLoss reduction ('mean' over all N*K elements), as before.
    loss_soft_regu = torch.nn.KLDivLoss()(F.log_softmax(outputs, dim=1),
                                          F.softmax(teacher_soft / T, dim=1))

    return (1. - alpha) * loss_CE + alpha * loss_soft_regu

time: 0 ns (started: 2024-05-06 06:46:57 +06:00)


In [10]:
def evaluate_kd(model, dataloader, device = 'cuda', crivice = 1):
    """Evaluate `model` on every batch of `dataloader` and print classification metrics.

    Prints weighted precision/recall/F1, Cohen's kappa, sklearn's classification report,
    the confusion matrix and per-class accuracy (confusion-matrix diagonal / row sum),
    then logs the overall accuracy.

    Args:
        model: (torch.nn.Module) the neural network; switched to eval mode here
        dataloader: (DataLoader) yields (inputs, labels) batches
        device: device inputs and labels are moved to; must match the model's device
        crivice: unused; kept only so existing calls keep working

    Returns:
        dict with 'accuracy' (percentage of correct top-1 predictions) and 'loss'
        (always 0.0 — validation loss is deliberately not computed)
    """
    model.eval()
    total = 0
    correct = 0
    data_sc = []  # predicted labels, in dataloader order
    data_rc = []  # ground-truth labels, in dataloader order
    # Pure inference: without no_grad every forward pass records an autograd graph
    # that is never used, wasting GPU memory and time.
    with torch.no_grad():
        for data_batch, labels_batch in dataloader:
            data_batch, labels_batch = data_batch.to(device), labels_batch.to(device)
            output_batch = model(data_batch)
            _, predicted = output_batch.max(1)
            total += labels_batch.size(0)
            correct += predicted.eq(labels_batch).sum().item()
            data_sc.extend(predicted.cpu().numpy())
            data_rc.extend(labels_batch.cpu().numpy())

    loss = 0.0  # validation loss is not computed, to save time
    fscore = f1_score(data_rc, data_sc, average = 'weighted')
    precision = precision_score(data_rc, data_sc, average = 'weighted')
    recall = recall_score(data_rc, data_sc, average = 'weighted')
    kappa = cohen_kappa_score(data_rc, data_sc)
    C = classification_report(data_rc, data_sc)
    matrix = confusion_matrix(data_rc, data_sc)
    accuracy_class = matrix.diagonal()/matrix.sum(axis=1)

    print("Precision: ", precision)
    print("Recall: ", recall)
    print("Report: ", C)
    print("F1: ", fscore)
    print("Kappa: ", kappa)
    print("Matrix: ", matrix)
    print("Accuracy Class: ", accuracy_class)

    acc = 100. * correct / total
    logging.info("- Eval metrics, acc:{acc:.4f}, loss: {loss:.4f}".format(acc=acc, loss=loss))
    return {'accuracy': acc, 'loss': loss}

time: 0 ns (started: 2024-05-06 06:46:57 +06:00)


In [11]:
def train_and_evaluate_kd(model, train_dataloader, val_dataloader, optimizer,
                       loss_fn_kd, warmup_scheduler, num_epochs=100,
                       milestones=(30, 60, 80), lr_gamma=0.2,
                       checkpoint_dir=os.curdir, log_dir='./tensorboard/'):
    """Train `model` with `loss_fn_kd` and evaluate it on the validation set every epoch.

    Epoch 0 is the warm-up epoch: `warmup_scheduler` is stepped per batch inside
    train_kd. From epoch 1 on, a MultiStepLR multiplies the LR by `lr_gamma` when the
    epoch index reaches each entry of `milestones`. After every epoch:
    - last.pth.tar and eval_last_result.json are written to `checkpoint_dir`;
    - best.pth.tar and eval_best_result.json are also written there when the
      validation accuracy ties or beats the best so far;
    - train/val accuracy and loss are logged to TensorBoard in `log_dir`.

    Args:
        model, train_dataloader, val_dataloader, optimizer: training objects.
        loss_fn_kd: callable (outputs, labels) -> scalar loss tensor.
        warmup_scheduler: per-batch LR scheduler used during epoch 0.
        num_epochs: number of epochs to run.
        milestones: epoch indices at which MultiStepLR decays the LR.
        lr_gamma: multiplicative LR decay factor at each milestone.
        checkpoint_dir: folder for checkpoints and metric JSON files.
        log_dir: TensorBoard log folder.
    """
    model = model.to(device)
    scheduler = MultiStepLR(optimizer, milestones=list(milestones), gamma=lr_gamma)
    best_val_acc = 0.0
    writer = SummaryWriter(log_dir=log_dir)
    try:
        for epoch in range(num_epochs):
            if epoch > 0:   # epoch 0 uses the per-batch warm-up schedule instead
                scheduler.step()
            logging.info("Epoch {}/{}, lr:{}".format(epoch + 1, num_epochs, optimizer.param_groups[0]['lr']))
            print(epoch)

            train_acc, train_loss = train_kd(model, optimizer, loss_fn_kd, train_dataloader, warmup_scheduler, epoch)
            # Evaluate on the same device the model was moved to.
            val_metrics = evaluate_kd(model, val_dataloader, device=device)
            val_acc = val_metrics['accuracy']
            # '>=': on ties the most recent epoch becomes the "best" checkpoint.
            is_best = val_acc >= best_val_acc

            save_checkpoint({'epoch': epoch + 1,
                             'state_dict': model.state_dict(),
                             'optim_dict': optimizer.state_dict()},
                            is_best=is_best,
                            checkpoint=checkpoint_dir)
            print("Val Accuracy: ", val_acc)
            if is_best:
                logging.info("*********** Hurray ! Found new best accuracy *****************")
                best_val_acc = val_acc
                print("*********** Hurray ! Found new best accuracy *****************")
                print("Val Accuracy: ", val_acc)
                save_dict_to_json(val_metrics, os.path.join(checkpoint_dir, "eval_best_result.json"))

            save_dict_to_json(val_metrics, os.path.join(checkpoint_dir, "eval_last_result.json"))

            writer.add_scalar('Train_accuracy', train_acc, epoch)
            writer.add_scalar('Train_loss', train_loss, epoch)
            writer.add_scalar('Test_accuracy', val_metrics['accuracy'], epoch)
            writer.add_scalar('Test_loss', val_metrics['loss'], epoch)
    finally:
        # Flush and close TensorBoard event files even if training is interrupted.
        writer.close()


# Defining train_kd functions
def train_kd(model, optimizer, loss_fn_kd, dataloader, warmup_scheduler, epoch):
    """Train `model` for one pass over `dataloader` and return (accuracy %, mean loss).

    Batches are moved to the module-level `device`. During epoch 0 (`epoch <= 0`)
    `warmup_scheduler` is stepped once per batch, before the forward pass.

    Args:
        model: (torch.nn.Module) network to train; switched to train mode here.
        optimizer: optimizer updating `model`'s parameters.
        loss_fn_kd: callable (outputs, labels) -> scalar loss tensor.
        dataloader: yields (inputs, labels) batches.
        warmup_scheduler: per-batch LR scheduler used only in epoch 0.
        epoch: 0-based epoch index.

    Returns:
        (acc, avg_loss): top-1 training accuracy in percent (0.0 for an empty loader)
        and the sample-weighted mean loss.
    """
    model.train()
    loss_avg = RunningAverage()   # unweighted per-batch mean, shown in the progress bar
    losses = AverageMeter()       # sample-weighted mean, returned and logged
    total = 0
    correct = 0
    with tqdm(total=len(dataloader)) as t:
        for train_batch, labels_batch in dataloader:
            if epoch <= 0:
                warmup_scheduler.step()

            train_batch, labels_batch = train_batch.to(device), labels_batch.to(device)

            output_batch = model(train_batch)
            loss = loss_fn_kd(output_batch, labels_batch)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # One device->host sync per batch; keeps plain floats (not GPU tensors) in the meters.
            loss_value = loss.item()
            _, predicted = output_batch.max(1)
            total += labels_batch.size(0)
            correct += predicted.eq(labels_batch).sum().item()
            loss_avg.update(loss_value)
            losses.update(loss_value, train_batch.size(0))

            t.set_postfix(loss='{:05.3f}'.format(loss_avg()), lr='{:05.6f}'.format(optimizer.param_groups[0]['lr']))
            t.update()

    acc = 100. * correct / total if total else 0.0
    logging.info("- Train accuracy: {acc:.4f}, training loss: {loss:.4f}".format(acc = acc, loss = losses.avg))
    return acc, losses.avg

time: 0 ns (started: 2024-05-06 06:46:57 +06:00)


In [14]:
# ImageNet-pretrained ResNet-18 with its classification head replaced.
NUM_CLASSES = 16  # number of output units; must cover every ImageFolder label index
model = torchvision.models.resnet18(weights = "ResNet18_Weights.IMAGENET1K_V1")
# Read the feature width from the existing head instead of hard-coding 512.
model.fc = torch.nn.Linear(model.fc.in_features, NUM_CLASSES)
# ImageFolder labels run 0..len(classes)-1; a narrower head would fail inside cross-entropy.
assert len(train_data.classes) <= NUM_CLASSES, (
    "{} training classes but only {} outputs".format(len(train_data.classes), NUM_CLASSES))

time: 188 ms (started: 2024-05-06 06:47:10 +06:00)


In [15]:
# Display the module tree of the ResNet-18 built above (with its replaced fc head).
model

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

time: 0 ns (started: 2024-05-06 06:47:11 +06:00)


In [None]:
import warnings
# NOTE(review): this silences every warning for the rest of the session (including
# sklearn undefined-metric and PyTorch deprecation warnings) — consider narrowing it.
warnings.filterwarnings("ignore")

iter_per_epoch = len(train_dataloader)
# Linear LR scaling: 0.1 at the reference batch size of 128.
optimizer = torch.optim.SGD(
    model.parameters(),
    lr=0.1 * (batch_size / 128),
    momentum=0.9,
    weight_decay=5e-4,
)
# Ramp the LR up over the first epoch (train_kd steps this once per batch in epoch 0).
warmup_scheduler = WarmUpLR(optimizer, iter_per_epoch * 1)
loss = loss_kd_regularization
train_and_evaluate_kd(
    model=model,
    train_dataloader=train_dataloader,
    val_dataloader=val_dataloader,
    optimizer=optimizer,
    loss_fn_kd=loss,
    warmup_scheduler=warmup_scheduler,
)

0


100%|█████████████████████████████████████████████████████| 1393/1393 [05:06<00:00,  4.54it/s, loss=0.090, lr=0.100000]


Precision:  0.8146182076315341
Recall:  0.8054674685620558
Report:                precision    recall  f1-score   support

           0       0.95      0.78      0.86      4465
           1       0.77      0.66      0.71      1945
           2       0.00      0.00      0.00        93
           3       0.71      0.89      0.79      2384
           4       0.28      0.65      0.40       170
           5       0.79      0.96      0.87      7101
           6       0.00      0.00      0.00         3
           7       0.00      0.00      0.00        55
           8       0.24      0.07      0.11       179
           9       0.00      0.00      0.00         6
          10       0.00      0.00      0.00        63
          11       0.00      0.00      0.00        46
          12       0.98      0.48      0.65      1780

    accuracy                           0.81     18290
   macro avg       0.36      0.35      0.34     18290
weighted avg       0.81      0.81      0.79     18290

F1:  0.7938

100%|█████████████████████████████████████████████████████| 1393/1393 [05:07<00:00,  4.52it/s, loss=0.078, lr=0.100000]


Precision:  0.7886068664822813
Recall:  0.7661016949152543
Report:                precision    recall  f1-score   support

           0       0.97      0.65      0.78      4465
           1       0.61      0.57      0.59      1945
           2       0.30      0.14      0.19        93
           3       0.60      0.94      0.73      2384
           4       0.55      0.44      0.49       170
           5       0.85      0.89      0.87      7101
           6       0.00      0.00      0.00         3
           7       1.00      0.02      0.04        55
           8       0.20      0.26      0.23       179
           9       0.00      0.00      0.00         6
          10       0.00      0.00      0.00        63
          11       0.20      0.02      0.04        46
          12       0.71      0.72      0.71      1780

    accuracy                           0.77     18290
   macro avg       0.46      0.36      0.36     18290
weighted avg       0.79      0.77      0.76     18290

F1:  0.7627

100%|█████████████████████████████████████████████████████| 1393/1393 [05:06<00:00,  4.55it/s, loss=0.077, lr=0.100000]


Precision:  0.766154588289262
Recall:  0.6945872061235648
Report:                precision    recall  f1-score   support

           0       0.63      0.92      0.75      4465
           1       0.44      0.88      0.59      1945
           2       0.40      0.06      0.11        93
           3       0.81      0.57      0.67      2384
           4       0.35      0.45      0.40       170
           5       0.92      0.64      0.75      7101
           6       0.00      0.00      0.00         3
           7       0.33      0.02      0.03        55
           8       0.20      0.12      0.15       179
           9       0.00      0.00      0.00         6
          10       1.00      0.03      0.06        63
          11       0.00      0.00      0.00        46
          12       0.92      0.48      0.63      1780

    accuracy                           0.69     18290
   macro avg       0.46      0.32      0.32     18290
weighted avg       0.77      0.69      0.69     18290

F1:  0.69396

100%|█████████████████████████████████████████████████████| 1393/1393 [05:06<00:00,  4.54it/s, loss=0.079, lr=0.100000]


Precision:  0.7304158693688233
Recall:  0.6370694368507381
Report:                precision    recall  f1-score   support

           0       0.53      0.97      0.69      4465
           1       0.59      0.81      0.68      1945
           2       0.14      0.16      0.15        93
           3       0.75      0.55      0.63      2384
           4       0.14      0.60      0.23       170
           5       0.92      0.49      0.64      7101
           6       0.00      0.00      0.00         3
           7       0.00      0.00      0.00        55
           8       0.64      0.04      0.07       179
           9       0.00      0.00      0.00         6
          10       0.00      0.00      0.00        63
          11       0.00      0.00      0.00        46
          12       0.75      0.45      0.56      1780

    accuracy                           0.64     18290
   macro avg       0.34      0.31      0.28     18290
weighted avg       0.73      0.64      0.63     18290

F1:  0.6311

100%|█████████████████████████████████████████████████████| 1393/1393 [05:06<00:00,  4.55it/s, loss=0.079, lr=0.100000]


Precision:  0.7221160715606795
Recall:  0.5420448332422089
Report:                precision    recall  f1-score   support

           0       0.96      0.29      0.44      4465
           1       0.93      0.30      0.45      1945
           2       0.00      0.00      0.00        93
           3       0.34      0.95      0.50      2384
           4       0.84      0.25      0.38       170
           5       0.80      0.68      0.73      7101
           6       0.00      0.00      0.00         3
           7       0.00      0.00      0.00        55
           8       0.11      0.01      0.01       179
           9       0.00      0.00      0.00         6
          10       0.00      0.00      0.00        63
          11       0.00      0.00      0.00        46
          12       0.26      0.52      0.35      1780

    accuracy                           0.54     18290
   macro avg       0.33      0.23      0.22     18290
weighted avg       0.72      0.54      0.54     18290

F1:  0.5436

100%|█████████████████████████████████████████████████████| 1393/1393 [05:06<00:00,  4.54it/s, loss=0.079, lr=0.100000]


Precision:  0.8059550752066059
Recall:  0.809896118097321
Report:                precision    recall  f1-score   support

           0       0.84      0.93      0.88      4465
           1       0.86      0.65      0.74      1945
           2       0.75      0.06      0.12        93
           3       0.76      0.78      0.77      2384
           4       0.75      0.28      0.40       170
           5       0.80      0.93      0.86      7101
           6       0.00      0.00      0.00         3
           7       0.00      0.00      0.00        55
           8       0.25      0.15      0.19       179
           9       0.00      0.00      0.00         6
          10       0.00      0.00      0.00        63
          11       0.00      0.00      0.00        46
          12       0.91      0.46      0.61      1780

    accuracy                           0.81     18290
   macro avg       0.45      0.33      0.35     18290
weighted avg       0.81      0.81      0.79     18290

F1:  0.79360

100%|█████████████████████████████████████████████████████| 1393/1393 [05:06<00:00,  4.55it/s, loss=0.079, lr=0.100000]


Precision:  0.7712826463040708
Recall:  0.7523236741388737
Report:                precision    recall  f1-score   support

           0       0.89      0.82      0.85      4465
           1       0.73      0.72      0.73      1945
           2       0.00      0.00      0.00        93
           3       0.83      0.59      0.69      2384
           4       0.38      0.62      0.47       170
           5       0.69      0.97      0.81      7101
           6       0.00      0.00      0.00         3
           7       1.00      0.02      0.04        55
           8       0.43      0.02      0.03       179
           9       0.00      0.00      0.00         6
          10       0.00      0.00      0.00        63
          11       0.00      0.00      0.00        46
          12       0.92      0.14      0.25      1780

    accuracy                           0.75     18290
   macro avg       0.45      0.30      0.30     18290
weighted avg       0.77      0.75      0.72     18290

F1:  0.7188

100%|█████████████████████████████████████████████████████| 1393/1393 [05:06<00:00,  4.55it/s, loss=0.079, lr=0.100000]


Precision:  0.7760901813270484
Recall:  0.733624931656643
Report:                precision    recall  f1-score   support

           0       0.93      0.74      0.82      4465
           1       0.80      0.39      0.53      1945
           2       0.75      0.03      0.06        93
           3       0.49      0.94      0.64      2384
           4       0.45      0.12      0.19       170
           5       0.77      0.92      0.84      7101
           6       0.00      0.00      0.00         3
           7       0.00      0.00      0.00        55
           8       0.27      0.17      0.21       179
           9       0.00      0.00      0.00         6
          10       0.00      0.00      0.00        63
          11       0.00      0.00      0.00        46
          12       0.94      0.30      0.45      1780

    accuracy                           0.73     18290
   macro avg       0.42      0.28      0.29     18290
weighted avg       0.78      0.73      0.71     18290

F1:  0.71417

100%|█████████████████████████████████████████████████████| 1393/1393 [05:04<00:00,  4.57it/s, loss=0.079, lr=0.100000]


Precision:  0.7964868918727779
Recall:  0.7911427009294697
Report:                precision    recall  f1-score   support

           0       0.81      0.92      0.86      4465
           1       0.77      0.57      0.65      1945
           2       0.68      0.16      0.26        93
           3       0.60      0.89      0.72      2384
           4       0.55      0.54      0.54       170
           5       0.91      0.83      0.87      7101
           6       0.00      0.00      0.00         3
           7       0.00      0.00      0.00        55
           8       0.75      0.07      0.12       179
           9       0.00      0.00      0.00         6
          10       0.00      0.00      0.00        63
          11       0.00      0.00      0.00        46
          12       0.72      0.63      0.67      1780

    accuracy                           0.79     18290
   macro avg       0.45      0.35      0.36     18290
weighted avg       0.80      0.79      0.78     18290

F1:  0.7830

100%|█████████████████████████████████████████████████████| 1393/1393 [05:05<00:00,  4.56it/s, loss=0.079, lr=0.100000]


Precision:  0.7745433454829244
Recall:  0.6892837616183707
Report:                precision    recall  f1-score   support

           0       0.96      0.67      0.79      4465
           1       0.66      0.52      0.58      1945
           2       0.40      0.02      0.04        93
           3       0.48      0.96      0.64      2384
           4       0.82      0.35      0.49       170
           5       0.92      0.71      0.80      7101
           6       0.00      0.00      0.00         3
           7       1.00      0.04      0.07        55
           8       0.12      0.22      0.15       179
           9       0.00      0.00      0.00         6
          10       0.00      0.00      0.00        63
          11       0.00      0.00      0.00        46
          12       0.38      0.64      0.48      1780

    accuracy                           0.69     18290
   macro avg       0.44      0.32      0.31     18290
weighted avg       0.77      0.69      0.70     18290

F1:  0.7035

100%|█████████████████████████████████████████████████████| 1393/1393 [05:05<00:00,  4.56it/s, loss=0.079, lr=0.100000]


Precision:  0.824177195728374
Recall:  0.8185347184253691
Report:                precision    recall  f1-score   support

           0       0.89      0.91      0.90      4465
           1       0.77      0.71      0.74      1945
           2       0.78      0.15      0.25        93
           3       0.64      0.84      0.73      2384
           4       0.39      0.64      0.48       170
           5       0.87      0.91      0.89      7101
           6       0.00      0.00      0.00         3
           7       1.00      0.04      0.07        55
           8       0.62      0.10      0.17       179
           9       0.00      0.00      0.00         6
          10       0.00      0.00      0.00        63
          11       0.00      0.00      0.00        46
          12       0.91      0.52      0.66      1780

    accuracy                           0.82     18290
   macro avg       0.53      0.37      0.38     18290
weighted avg       0.82      0.82      0.81     18290

F1:  0.80966

100%|█████████████████████████████████████████████████████| 1393/1393 [05:05<00:00,  4.56it/s, loss=0.079, lr=0.100000]


Precision:  0.7688717960717066
Recall:  0.7220885729907053
Report:                precision    recall  f1-score   support

           0       0.93      0.69      0.79      4465
           1       0.72      0.27      0.40      1945
           2       0.34      0.14      0.20        93
           3       0.65      0.81      0.72      2384
           4       0.54      0.40      0.46       170
           5       0.87      0.88      0.87      7101
           6       0.00      0.00      0.00         3
           7       0.00      0.00      0.00        55
           8       0.38      0.03      0.05       179
           9       0.00      0.00      0.00         6
          10       0.00      0.00      0.00        63
          11       0.00      0.00      0.00        46
          12       0.34      0.74      0.46      1780

    accuracy                           0.72     18290
   macro avg       0.37      0.30      0.30     18290
weighted avg       0.77      0.72      0.72     18290

F1:  0.7200

100%|█████████████████████████████████████████████████████| 1393/1393 [05:05<00:00,  4.57it/s, loss=0.078, lr=0.100000]


Precision:  0.7738604786651786
Recall:  0.7162930563149262
Report:                precision    recall  f1-score   support

           0       0.67      0.95      0.79      4465
           1       0.48      0.92      0.63      1945
           2       0.75      0.06      0.12        93
           3       0.81      0.50      0.62      2384
           4       0.37      0.57      0.45       170
           5       0.89      0.74      0.81      7101
           6       0.00      0.00      0.00         3
           7       1.00      0.02      0.04        55
           8       0.46      0.17      0.25       179
           9       0.00      0.00      0.00         6
          10       0.00      0.00      0.00        63
          11       0.00      0.00      0.00        46
          12       0.97      0.28      0.44      1780

    accuracy                           0.72     18290
   macro avg       0.49      0.32      0.32     18290
weighted avg       0.77      0.72      0.70     18290

F1:  0.7026

100%|█████████████████████████████████████████████████████| 1393/1393 [05:05<00:00,  4.56it/s, loss=0.078, lr=0.100000]


Precision:  0.6439728986841481
Recall:  0.447840349917988
Report:                precision    recall  f1-score   support

           0       0.37      0.93      0.53      4465
           1       0.82      0.43      0.56      1945
           2       0.00      0.00      0.00        93
           3       0.69      0.48      0.57      2384
           4       0.43      0.09      0.15       170
           5       0.91      0.21      0.34      7101
           6       0.00      0.00      0.00         3
           7       0.00      0.00      0.00        55
           8       0.00      0.00      0.00       179
           9       0.00      0.00      0.00         6
          10       0.00      0.00      0.00        63
          11       0.00      0.00      0.00        46
          12       0.21      0.30      0.25      1780

    accuracy                           0.45     18290
   macro avg       0.26      0.19      0.18     18290
weighted avg       0.64      0.45      0.42     18290

F1:  0.42035

100%|█████████████████████████████████████████████████████| 1393/1393 [05:06<00:00,  4.54it/s, loss=0.078, lr=0.100000]


Precision:  0.8043649847433892
Recall:  0.792837616183707
Report:                precision    recall  f1-score   support

           0       0.82      0.91      0.87      4465
           1       0.83      0.70      0.76      1945
           2       1.00      0.04      0.08        93
           3       0.75      0.83      0.79      2384
           4       0.17      0.66      0.27       170
           5       0.83      0.90      0.86      7101
           6       0.00      0.00      0.00         3
           7       0.00      0.00      0.00        55
           8       0.71      0.06      0.10       179
           9       0.00      0.00      0.00         6
          10       0.00      0.00      0.00        63
          11       0.00      0.00      0.00        46
          12       0.84      0.32      0.46      1780

    accuracy                           0.79     18290
   macro avg       0.46      0.34      0.32     18290
weighted avg       0.80      0.79      0.78     18290

F1:  0.77887

100%|█████████████████████████████████████████████████████| 1393/1393 [05:05<00:00,  4.57it/s, loss=0.078, lr=0.100000]


Precision:  0.7857272369808905
Recall:  0.7917987971569164
Report:                precision    recall  f1-score   support

           0       0.79      0.91      0.84      4465
           1       0.77      0.63      0.69      1945
           2       0.39      0.26      0.31        93
           3       0.77      0.77      0.77      2384
           4       0.52      0.56      0.54       170
           5       0.86      0.86      0.86      7101
           6       0.00      0.00      0.00         3
           7       0.00      0.00      0.00        55
           8       1.00      0.01      0.01       179
           9       0.00      0.00      0.00         6
          10       0.00      0.00      0.00        63
          11       0.00      0.00      0.00        46
          12       0.64      0.65      0.64      1780

    accuracy                           0.79     18290
   macro avg       0.44      0.36      0.36     18290
weighted avg       0.79      0.79      0.78     18290

F1:  0.7821

100%|█████████████████████████████████████████████████████| 1393/1393 [05:06<00:00,  4.55it/s, loss=0.078, lr=0.100000]


Precision:  0.7975001586560162
Recall:  0.7997813012575178
Report:                precision    recall  f1-score   support

           0       0.84      0.92      0.88      4465
           1       0.67      0.81      0.73      1945
           2       0.62      0.05      0.10        93
           3       0.80      0.71      0.75      2384
           4       0.64      0.42      0.51       170
           5       0.80      0.93      0.86      7101
           6       0.00      0.00      0.00         3
           7       0.00      0.00      0.00        55
           8       0.05      0.01      0.01       179
           9       0.00      0.00      0.00         6
          10       0.00      0.00      0.00        63
          11       0.00      0.00      0.00        46
          12       0.98      0.31      0.48      1780

    accuracy                           0.80     18290
   macro avg       0.42      0.32      0.33     18290
weighted avg       0.80      0.80      0.78     18290

F1:  0.7771

100%|█████████████████████████████████████████████████████| 1393/1393 [05:05<00:00,  4.56it/s, loss=0.078, lr=0.100000]


Precision:  0.8042141129566406
Recall:  0.791853471842537
Report:                precision    recall  f1-score   support

           0       0.94      0.78      0.86      4465
           1       0.85      0.60      0.70      1945
           2       0.75      0.03      0.06        93
           3       0.75      0.83      0.79      2384
           4       0.55      0.39      0.46       170
           5       0.73      0.96      0.83      7101
           6       0.00      0.00      0.00         3
           7       0.33      0.02      0.03        55
           8       0.19      0.15      0.17       179
           9       0.00      0.00      0.00         6
          10       0.00      0.00      0.00        63
          11       0.00      0.00      0.00        46
          12       0.92      0.51      0.65      1780

    accuracy                           0.79     18290
   macro avg       0.46      0.33      0.35     18290
weighted avg       0.80      0.79      0.78     18290

F1:  0.77968

100%|█████████████████████████████████████████████████████| 1393/1393 [05:06<00:00,  4.55it/s, loss=0.078, lr=0.100000]


Precision:  0.8229810105006946
Recall:  0.8291962821213777
Report:                precision    recall  f1-score   support

           0       0.94      0.89      0.91      4465
           1       0.73      0.79      0.76      1945
           2       0.00      0.00      0.00        93
           3       0.75      0.80      0.78      2384
           4       0.46      0.61      0.53       170
           5       0.83      0.95      0.89      7101
           6       0.00      0.00      0.00         3
           7       0.00      0.00      0.00        55
           8       0.37      0.16      0.23       179
           9       0.00      0.00      0.00         6
          10       0.00      0.00      0.00        63
          11       0.00      0.00      0.00        46
          12       0.92      0.49      0.64      1780

    accuracy                           0.83     18290
   macro avg       0.38      0.36      0.36     18290
weighted avg       0.82      0.83      0.82     18290

F1:  0.8169

100%|█████████████████████████████████████████████████████| 1393/1393 [05:05<00:00,  4.56it/s, loss=0.078, lr=0.100000]


Precision:  0.8029830420081336
Recall:  0.7724986331328595
Report:                precision    recall  f1-score   support

           0       0.70      0.98      0.82      4465
           1       0.85      0.62      0.72      1945
           2       0.60      0.40      0.48        93
           3       0.76      0.81      0.78      2384
           4       0.20      0.64      0.30       170
           5       0.87      0.84      0.86      7101
           6       0.00      0.00      0.00         3
           7       1.00      0.02      0.04        55
           8       0.31      0.22      0.26       179
           9       0.00      0.00      0.00         6
          10       0.00      0.00      0.00        63
          11       0.00      0.00      0.00        46
          12       0.97      0.24      0.39      1780

    accuracy                           0.77     18290
   macro avg       0.48      0.37      0.36     18290
weighted avg       0.80      0.77      0.76     18290

F1:  0.7561

100%|█████████████████████████████████████████████████████| 1393/1393 [05:06<00:00,  4.55it/s, loss=0.078, lr=0.100000]


Precision:  0.7805690652861229
Recall:  0.7241662110442865
Report:                precision    recall  f1-score   support

           0       0.98      0.55      0.70      4465
           1       0.86      0.58      0.70      1945
           2       0.81      0.18      0.30        93
           3       0.72      0.79      0.76      2384
           4       0.41      0.65      0.50       170
           5       0.64      0.98      0.78      7101
           6       0.00      0.00      0.00         3
           7       0.43      0.11      0.17        55
           8       0.13      0.03      0.05       179
           9       0.00      0.00      0.00         6
          10       0.00      0.00      0.00        63
          11       0.00      0.00      0.00        46
          12       0.98      0.38      0.55      1780

    accuracy                           0.72     18290
   macro avg       0.46      0.33      0.35     18290
weighted avg       0.78      0.72      0.71     18290

F1:  0.7059

100%|█████████████████████████████████████████████████████| 1393/1393 [05:05<00:00,  4.57it/s, loss=0.078, lr=0.100000]


Precision:  0.806942062612741
Recall:  0.8133406232914161
Report:                precision    recall  f1-score   support

           0       0.85      0.91      0.88      4465
           1       0.75      0.72      0.74      1945
           2       0.31      0.09      0.13        93
           3       0.66      0.89      0.76      2384
           4       0.53      0.51      0.52       170
           5       0.89      0.85      0.87      7101
           6       0.00      0.00      0.00         3
           7       0.00      0.00      0.00        55
           8       0.46      0.03      0.06       179
           9       0.00      0.00      0.00         6
          10       0.00      0.00      0.00        63
          11       0.00      0.00      0.00        46
          12       0.78      0.64      0.71      1780

    accuracy                           0.81     18290
   macro avg       0.40      0.36      0.36     18290
weighted avg       0.81      0.81      0.80     18290

F1:  0.80480

100%|█████████████████████████████████████████████████████| 1393/1393 [05:05<00:00,  4.55it/s, loss=0.078, lr=0.100000]


Precision:  0.7806747671236824
Recall:  0.779606342263532
Report:                precision    recall  f1-score   support

           0       0.84      0.84      0.84      4465
           1       0.74      0.68      0.71      1945
           2       0.74      0.15      0.25        93
           3       0.74      0.78      0.76      2384
           4       0.60      0.44      0.50       170
           5       0.87      0.84      0.85      7101
           6       0.00      0.00      0.00         3
           7       0.00      0.00      0.00        55
           8       0.20      0.17      0.18       179
           9       0.00      0.00      0.00         6
          10       0.00      0.00      0.00        63
          11       0.00      0.00      0.00        46
          12       0.52      0.69      0.59      1780

    accuracy                           0.78     18290
   macro avg       0.40      0.35      0.36     18290
weighted avg       0.78      0.78      0.78     18290

F1:  0.77742

100%|█████████████████████████████████████████████████████| 1393/1393 [05:05<00:00,  4.56it/s, loss=0.078, lr=0.100000]


Precision:  0.7660675957767155
Recall:  0.7393657736468016
Report:                precision    recall  f1-score   support

           0       0.92      0.77      0.84      4465
           1       0.44      0.86      0.58      1945
           2       0.29      0.26      0.27        93
           3       0.60      0.46      0.52      2384
           4       0.42      0.45      0.44       170
           5       0.86      0.86      0.86      7101
           6       0.00      0.00      0.00         3
           7       0.45      0.09      0.15        55
           8       0.24      0.18      0.21       179
           9       0.00      0.00      0.00         6
          10       0.00      0.00      0.00        63
          11       0.00      0.00      0.00        46
          12       0.76      0.60      0.67      1780

    accuracy                           0.74     18290
   macro avg       0.38      0.35      0.35     18290
weighted avg       0.77      0.74      0.74     18290

F1:  0.7412

100%|█████████████████████████████████████████████████████| 1393/1393 [05:03<00:00,  4.59it/s, loss=0.078, lr=0.100000]


Precision:  0.8223536864168983
Recall:  0.8052487698195735
Report:                precision    recall  f1-score   support

           0       0.90      0.84      0.87      4465
           1       0.82      0.71      0.76      1945
           2       0.75      0.13      0.22        93
           3       0.75      0.85      0.79      2384
           4       0.21      0.66      0.32       170
           5       0.80      0.95      0.87      7101
           6       0.00      0.00      0.00         3
           7       0.62      0.09      0.16        55
           8       0.62      0.04      0.08       179
           9       0.00      0.00      0.00         6
          10       0.00      0.00      0.00        63
          11       0.00      0.00      0.00        46
          12       0.98      0.39      0.56      1780

    accuracy                           0.81     18290
   macro avg       0.50      0.36      0.36     18290
weighted avg       0.82      0.81      0.79     18290

F1:  0.7926

100%|█████████████████████████████████████████████████████| 1393/1393 [05:05<00:00,  4.56it/s, loss=0.078, lr=0.100000]


Precision:  0.7863710302065925
Recall:  0.7697648988518316
Report:                precision    recall  f1-score   support

           0       0.94      0.77      0.85      4465
           1       0.63      0.64      0.63      1945
           2       0.60      0.16      0.25        93
           3       0.58      0.92      0.71      2384
           4       0.58      0.55      0.57       170
           5       0.81      0.89      0.85      7101
           6       0.00      0.00      0.00         3
           7       0.00      0.00      0.00        55
           8       0.45      0.08      0.13       179
           9       0.00      0.00      0.00         6
          10       0.00      0.00      0.00        63
          11       0.00      0.00      0.00        46
          12       0.90      0.42      0.57      1780

    accuracy                           0.77     18290
   macro avg       0.42      0.34      0.35     18290
weighted avg       0.79      0.77      0.76     18290

F1:  0.7599

100%|█████████████████████████████████████████████████████| 1393/1393 [05:03<00:00,  4.58it/s, loss=0.078, lr=0.100000]


Precision:  0.8007572567023751
Recall:  0.7533078184800437
Report:                precision    recall  f1-score   support

           0       0.91      0.81      0.86      4465
           1       0.85      0.61      0.71      1945
           2       0.74      0.15      0.25        93
           3       0.61      0.88      0.72      2384
           4       0.14      0.76      0.24       170
           5       0.78      0.89      0.83      7101
           6       0.00      0.00      0.00         3
           7       0.25      0.05      0.09        55
           8       0.50      0.02      0.03       179
           9       0.00      0.00      0.00         6
          10       0.00      0.00      0.00        63
          11       0.00      0.00      0.00        46
          12       0.97      0.19      0.32      1780

    accuracy                           0.75     18290
   macro avg       0.44      0.34      0.31     18290
weighted avg       0.80      0.75      0.74     18290

F1:  0.7384

100%|█████████████████████████████████████████████████████| 1393/1393 [05:07<00:00,  4.54it/s, loss=0.078, lr=0.100000]


Precision:  0.7737531657393886
Recall:  0.7333515582285401
Report:                precision    recall  f1-score   support

           0       0.71      0.90      0.79      4465
           1       0.46      0.84      0.60      1945
           2       0.39      0.15      0.22        93
           3       0.73      0.72      0.72      2384
           4       0.51      0.29      0.37       170
           5       0.91      0.73      0.81      7101
           6       0.00      0.00      0.00         3
           7       0.33      0.13      0.18        55
           8       0.75      0.02      0.03       179
           9       0.00      0.00      0.00         6
          10       0.00      0.00      0.00        63
          11       0.00      0.00      0.00        46
          12       0.90      0.46      0.61      1780

    accuracy                           0.73     18290
   macro avg       0.44      0.33      0.33     18290
weighted avg       0.77      0.73      0.73     18290

F1:  0.7298

100%|█████████████████████████████████████████████████████| 1393/1393 [05:07<00:00,  4.54it/s, loss=0.078, lr=0.100000]


Precision:  0.825393690585905
Recall:  0.8293603061782395
Report:                precision    recall  f1-score   support

           0       0.94      0.88      0.91      4465
           1       0.77      0.69      0.73      1945
           2       0.74      0.37      0.49        93
           3       0.69      0.84      0.76      2384
           4       0.43      0.48      0.45       170
           5       0.85      0.94      0.89      7101
           6       0.00      0.00      0.00         3
           7       1.00      0.04      0.07        55
           8       0.41      0.13      0.20       179
           9       0.00      0.00      0.00         6
          10       0.00      0.00      0.00        63
          11       0.00      0.00      0.00        46
          12       0.80      0.62      0.70      1780

    accuracy                           0.83     18290
   macro avg       0.51      0.38      0.40     18290
weighted avg       0.83      0.83      0.82     18290

F1:  0.82146

100%|█████████████████████████████████████████████████████| 1393/1393 [05:06<00:00,  4.54it/s, loss=0.078, lr=0.100000]


Precision:  0.8372110998851761
Recall:  0.8358119190814652
Report:                precision    recall  f1-score   support

           0       0.87      0.97      0.92      4465
           1       0.88      0.64      0.74      1945
           2       0.87      0.28      0.42        93
           3       0.71      0.89      0.79      2384
           4       0.37      0.64      0.47       170
           5       0.87      0.92      0.90      7101
           6       0.00      0.00      0.00         3
           7       0.82      0.16      0.27        55
           8       0.56      0.11      0.18       179
           9       0.00      0.00      0.00         6
          10       0.00      0.00      0.00        63
          11       0.00      0.00      0.00        46
          12       0.88      0.50      0.64      1780

    accuracy                           0.84     18290
   macro avg       0.52      0.39      0.41     18290
weighted avg       0.84      0.84      0.82     18290

F1:  0.8237

100%|█████████████████████████████████████████████████████| 1393/1393 [05:06<00:00,  4.55it/s, loss=0.075, lr=0.020000]


Precision:  0.839295568969346
Recall:  0.8206670311645708
Report:                precision    recall  f1-score   support

           0       0.79      0.98      0.87      4465
           1       0.85      0.77      0.81      1945
           2       0.61      0.44      0.51        93
           3       0.78      0.84      0.81      2384
           4       0.20      0.79      0.32       170
           5       0.91      0.84      0.87      7101
           6       0.00      0.00      0.00         3
           7       0.47      0.35      0.40        55
           8       0.71      0.11      0.19       179
           9       0.00      0.00      0.00         6
          10       0.00      0.00      0.00        63
          11       0.00      0.00      0.00        46
          12       0.92      0.56      0.69      1780

    accuracy                           0.82     18290
   macro avg       0.48      0.44      0.42     18290
weighted avg       0.84      0.82      0.82     18290

F1:  0.81843

100%|█████████████████████████████████████████████████████| 1393/1393 [05:05<00:00,  4.56it/s, loss=0.074, lr=0.020000]


Precision:  0.8545154294712972
Recall:  0.8524330235101149
Report:                precision    recall  f1-score   support

           0       0.92      0.92      0.92      4465
           1       0.73      0.80      0.76      1945
           2       0.72      0.47      0.57        93
           3       0.76      0.87      0.81      2384
           4       0.46      0.74      0.57       170
           5       0.89      0.94      0.91      7101
           6       0.00      0.00      0.00         3
           7       0.57      0.45      0.51        55
           8       0.54      0.19      0.28       179
           9       0.00      0.00      0.00         6
          10       0.00      0.00      0.00        63
          11       0.00      0.00      0.00        46
          12       0.98      0.53      0.69      1780

    accuracy                           0.85     18290
   macro avg       0.50      0.46      0.46     18290
weighted avg       0.85      0.85      0.84     18290

F1:  0.8449

100%|█████████████████████████████████████████████████████| 1393/1393 [05:06<00:00,  4.55it/s, loss=0.074, lr=0.020000]


Precision:  0.8629428200571168
Recall:  0.8597594313832696
Report:                precision    recall  f1-score   support

           0       0.89      0.97      0.93      4465
           1       0.86      0.75      0.80      1945
           2       0.58      0.58      0.58        93
           3       0.80      0.87      0.83      2384
           4       0.39      0.81      0.52       170
           5       0.88      0.95      0.91      7101
           6       0.00      0.00      0.00         3
           7       0.46      0.47      0.46        55
           8       0.62      0.20      0.30       179
           9       0.00      0.00      0.00         6
          10       0.00      0.00      0.00        63
          11       0.00      0.00      0.00        46
          12       0.98      0.50      0.66      1780

    accuracy                           0.86     18290
   macro avg       0.50      0.47      0.46     18290
weighted avg       0.86      0.86      0.85     18290

F1:  0.8502

100%|█████████████████████████████████████████████████████| 1393/1393 [05:05<00:00,  4.55it/s, loss=0.074, lr=0.020000]


Precision:  0.8668307186766859
Recall:  0.8679606342263532
Report:                precision    recall  f1-score   support

           0       0.95      0.91      0.93      4465
           1       0.81      0.79      0.80      1945
           2       0.89      0.35      0.51        93
           3       0.74      0.90      0.82      2384
           4       0.68      0.65      0.66       170
           5       0.88      0.95      0.91      7101
           6       0.00      0.00      0.00         3
           7       0.48      0.42      0.45        55
           8       0.60      0.28      0.39       179
           9       0.00      0.00      0.00         6
          10       0.00      0.00      0.00        63
          11       0.00      0.00      0.00        46
          12       0.94      0.66      0.77      1780

    accuracy                           0.87     18290
   macro avg       0.54      0.45      0.48     18290
weighted avg       0.87      0.87      0.86     18290

F1:  0.8620

100%|█████████████████████████████████████████████████████| 1393/1393 [05:07<00:00,  4.53it/s, loss=0.074, lr=0.020000]


Precision:  0.8466406557408154
Recall:  0.8420448332422089
Report:                precision    recall  f1-score   support

           0       0.79      0.98      0.87      4465
           1       0.82      0.78      0.80      1945
           2       0.70      0.34      0.46        93
           3       0.84      0.80      0.82      2384
           4       0.34      0.74      0.47       170
           5       0.92      0.87      0.89      7101
           6       0.00      0.00      0.00         3
           7       0.41      0.35      0.38        55
           8       0.41      0.34      0.37       179
           9       0.00      0.00      0.00         6
          10       0.00      0.00      0.00        63
          11       0.00      0.00      0.00        46
          12       0.89      0.69      0.78      1780

    accuracy                           0.84     18290
   macro avg       0.47      0.45      0.45     18290
weighted avg       0.85      0.84      0.84     18290

F1:  0.8391

100%|█████████████████████████████████████████████████████| 1393/1393 [05:05<00:00,  4.56it/s, loss=0.074, lr=0.020000]


Precision:  0.7737462042889577
Recall:  0.7240021869874248
Report:                precision    recall  f1-score   support

           0       0.98      0.47      0.64      4465
           1       0.83      0.55      0.66      1945
           2       1.00      0.05      0.10        93
           3       0.61      0.90      0.73      2384
           4       0.82      0.33      0.47       170
           5       0.74      0.92      0.82      7101
           6       0.00      0.00      0.00         3
           7       0.00      0.00      0.00        55
           8       0.13      0.39      0.19       179
           9       0.00      0.00      0.00         6
          10       0.00      0.00      0.00        63
          11       0.00      0.00      0.00        46
          12       0.68      0.71      0.70      1780

    accuracy                           0.72     18290
   macro avg       0.45      0.33      0.33     18290
weighted avg       0.77      0.72      0.71     18290

F1:  0.7124

100%|█████████████████████████████████████████████████████| 1393/1393 [05:04<00:00,  4.57it/s, loss=0.074, lr=0.020000]


Precision:  0.8389013387469404
Recall:  0.8359759431383269
Report:                precision    recall  f1-score   support

           0       0.83      0.97      0.90      4465
           1       0.70      0.83      0.76      1945
           2       0.72      0.30      0.42        93
           3       0.82      0.76      0.79      2384
           4       0.40      0.76      0.53       170
           5       0.89      0.89      0.89      7101
           6       0.00      0.00      0.00         3
           7       0.52      0.25      0.34        55
           8       0.64      0.15      0.24       179
           9       0.00      0.00      0.00         6
          10       0.00      0.00      0.00        63
          11       0.00      0.00      0.00        46
          12       0.95      0.56      0.71      1780

    accuracy                           0.84     18290
   macro avg       0.50      0.42      0.43     18290
weighted avg       0.84      0.84      0.83     18290

F1:  0.8281

100%|█████████████████████████████████████████████████████| 1393/1393 [05:04<00:00,  4.57it/s, loss=0.074, lr=0.020000]


Precision:  0.858342472838464
Recall:  0.855385456533625
Report:                precision    recall  f1-score   support

           0       0.90      0.97      0.93      4465
           1       0.73      0.80      0.76      1945
           2       0.82      0.48      0.61        93
           3       0.77      0.80      0.79      2384
           4       0.30      0.75      0.43       170
           5       0.94      0.89      0.92      7101
           6       0.00      0.00      0.00         3
           7       0.69      0.20      0.31        55
           8       0.55      0.34      0.42       179
           9       0.00      0.00      0.00         6
          10       0.00      0.00      0.00        63
          11       0.00      0.00      0.00        46
          12       0.84      0.73      0.78      1780

    accuracy                           0.86     18290
   macro avg       0.50      0.46      0.46     18290
weighted avg       0.86      0.86      0.85     18290

F1:  0.853936

 58%|███████████████████████████████▍                      | 810/1393 [02:58<01:54,  5.08it/s, loss=0.074, lr=0.020000]

In [23]:
def create_model(checkpoint_path=os.path.join("classifier", "best88.pth.tar"), num_classes=16):
    """Build a ResNet-18 classifier and load fine-tuned weights from a checkpoint.

    Args:
        checkpoint_path: path to a ``torch.save``'d dict whose ``'state_dict'``
            entry holds the full model weights (backbone and classifier head).
        num_classes: number of output classes of the final FC layer; must match
            the checkpoint.

    Returns:
        torchvision ResNet-18 with the checkpoint weights loaded (on CPU).
    """
    # No ImageNet download needed: load_state_dict (strict by default) below
    # overwrites every parameter and buffer of the network.
    model = torchvision.models.resnet18(weights=None)
    # Replace the 1000-way ImageNet head *before* loading so the fc shapes match
    # the fine-tuned checkpoint.
    model.fc = torch.nn.Linear(model.fc.in_features, num_classes)
    # map_location="cpu" lets a checkpoint saved from a GPU load on a CPU-only
    # machine; the tensors are copied into the (CPU) model parameters either way.
    checkpoint = torch.load(checkpoint_path, map_location="cpu")
    model.load_state_dict(checkpoint["state_dict"])
    return model

time: 0 ns (started: 2024-05-06 06:48:45 +06:00)


In [24]:
# Rebuild the classifier from the checkpoint. model.eval() switches the BatchNorm
# layers to their running statistics; it returns the module, so this cell
# displays the architecture.
model = create_model()
model.eval()

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

time: 329 ms (started: 2024-05-06 06:48:46 +06:00)


In [25]:
def count_parameters(model):
    """Return the total number of scalar parameters in ``model`` that have requires_grad=True."""
    trainable = 0
    for param in model.parameters():
        if param.requires_grad:
            trainable += param.numel()
    return trainable
def print_size_of_model(model):
    """Print and return the serialized size of ``model.state_dict()`` in megabytes.

    The state dict is written with ``torch.save`` into a private temporary
    directory. Nothing is created in, or overwritten in, the working directory,
    and the file is removed even if saving fails.

    Args:
        model: a ``torch.nn.Module``.

    Returns:
        float: checkpoint size in MB (1 MB = 1e6 bytes).
    """
    import tempfile

    with tempfile.TemporaryDirectory() as tmp_dir:
        # Keep the original file name. torch's zip format derives its internal
        # record prefix from it, so the measured size matches the old cwd-based save.
        path = os.path.join(tmp_dir, "temp_delme.p")
        torch.save(model.state_dict(), path)
        size_mb = os.path.getsize(path) / 1e6
    print('Size (MB):', size_mb)
    return size_mb
# Report the trainable parameter count and serialized size of the FP32 model.
# print_size_of_model prints on its own; wrapping it in print() only added a stray
# "None" line, and the trailing ';' suppresses any displayed return value.
print(count_parameters(model))
print_size_of_model(model);

11184720
Size (MB): 44.816376
None
time: 47 ms (started: 2024-05-06 06:48:48 +06:00)


In [26]:
# Evaluate the FP32 model on the GPU. These are the reference metrics for the
# quantized variants evaluated below.
model.to('cuda')
evaluate_kd(model, val_dataloader)

Precision:  0.9008639945306363
Recall:  0.9041006014215418
Report:                precision    recall  f1-score   support

           0       0.95      0.96      0.96      4465
           1       0.87      0.82      0.84      1945
           2       0.81      0.45      0.58        93
           3       0.84      0.89      0.86      2384
           4       0.62      0.73      0.67       170
           5       0.93      0.97      0.95      7101
           6       0.00      0.00      0.00         3
           7       0.61      0.49      0.55        55
           8       0.71      0.38      0.49       179
           9       0.00      0.00      0.00         6
          10       0.23      0.11      0.15        63
          11       0.28      0.20      0.23        46
          12       0.90      0.77      0.83      1780

    accuracy                           0.90     18290
   macro avg       0.59      0.52      0.55     18290
weighted avg       0.90      0.90      0.90     18290

F1:  0.9008

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


{'accuracy': 90.41006014215418, 'loss': 0.0}

time: 14.1 s (started: 2024-05-06 06:48:49 +06:00)


In [28]:
import nncf
import openvino as ov

INFO:nncf:NNCF initialized successfully. Supported frameworks detected: torch, openvino
time: 750 ms (started: 2024-05-06 07:05:00 +06:00)


In [19]:
def transform_fn(data_item):
    """Extract the model input from an ``(inputs, labels)`` batch for NNCF calibration.

    The labels are discarded. Unpacking enforces that each batch is exactly a pair.
    """
    inputs, _labels = data_item
    return inputs

# Calibration data for post-training quantization. NNCF applies transform_fn to
# each batch from train_dataloader to obtain the model input (labels are dropped).
calibration_dataset = nncf.Dataset(train_dataloader, transform_fn)

time: 0 ns (started: 2024-05-05 17:36:52 +06:00)


In [20]:
# Move the FP32 model to CPU before calibration/quantization. The quantized model
# is evaluated and exported on CPU later in the notebook.
model.to('cpu')

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

time: 32 ms (started: 2024-05-05 17:36:52 +06:00)


In [21]:
# Post-training INT8 quantization with NNCF using the calibration data above.
# The result is an NNCF-wrapped PyTorch model with 8-bit quantizer ops inserted
# (see the repr in the next cell).
quantized_model = nncf.quantize(model, calibration_dataset)



INFO:nncf:Compiling and loading torch extension: quantized_functions_cpu...
Reason: Command '['where', 'cl']' returned non-zero exit status 1.
INFO:nncf:Finished loading torch extension: quantized_functions_cpu




time: 4min 20s (started: 2024-05-05 17:36:52 +06:00)


In [22]:
# Display the quantized model to inspect where quantizers were inserted.
quantized_model

ResNet(
  (conv1): NNCFConv2d(
    3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False
    (pre_ops): ModuleDict(
      (0): UpdateWeight(
        (op): SymmetricQuantizer(bit=8, ch=True)
      )
    )
    (post_ops): ModuleDict()
  )
  (bn1): NNCFBatchNorm2d(
    64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True
    (pre_ops): ModuleDict()
    (post_ops): ModuleDict()
  )
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): NNCFConv2d(
        64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False
        (pre_ops): ModuleDict(
          (0): UpdateWeight(
            (op): SymmetricQuantizer(bit=8, ch=True)
          )
        )
        (post_ops): ModuleDict()
      )
      (bn1): NNCFBatchNorm2d(
        64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True
        (pre_ops): ModuleDict()
        (p

time: 0 ns (started: 2024-05-05 17:46:06 +06:00)


In [23]:
# Evaluate the NNCF-quantized PyTorch model on CPU.
quantized_model.to('cpu')
evaluate_kd(quantized_model, val_dataloader, device = 'cpu')

Precision:  0.9055983207731232
Recall:  0.9110442864953526
Report:                precision    recall  f1-score   support

           0       0.95      0.97      0.96      4465
           1       0.88      0.84      0.86      1945
           2       0.83      0.38      0.52        93
           3       0.86      0.89      0.87      2384
           4       0.64      0.84      0.72       170
           5       0.93      0.97      0.95      7101
           6       0.00      0.00      0.00         3
           7       0.90      0.47      0.62        55
           8       0.66      0.41      0.51       179
           9       0.00      0.00      0.00         6
          10       0.17      0.03      0.05        63
          11       0.00      0.00      0.00        46
          12       0.91      0.80      0.86      1780

    accuracy                           0.91     18290
   macro avg       0.59      0.51      0.53     18290
weighted avg       0.91      0.91      0.91     18290

F1:  0.9065

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


{'accuracy': 91.10442864953527, 'loss': 0.0}

time: 2min 24s (started: 2024-05-05 17:46:10 +06:00)


In [24]:
# Parameter count and serialized size of the NNCF-wrapped model. Both are slightly
# larger than the FP32 model above: quantizer parameters are added, and the weights
# are presumably still stored in FP32. Real INT8 storage only appears after the
# OpenVINO export below.
# print_size_of_model prints on its own; wrapping it in print() only added "None".
print(count_parameters(quantized_model))
print_size_of_model(quantized_model);

11189566
Size (MB): 44.923096
None
time: 125 ms (started: 2024-05-05 17:48:34 +06:00)


In [25]:
# Example batch used only to trace the PyTorch graph during conversion.
# The -1 in `input` marks the batch dimension as dynamic in the exported IR.
dummy_input = torch.randn(128, 3, 64, 64)

quantized_model_ir = ov.convert_model(quantized_model, example_input=dummy_input, input=[-1, 3, 64, 64])

# os.path.join instead of a hard-coded ".\\" path so the export also lands in
# classifier/ on non-Windows systems.
# NOTE(review): the later IR-loading cell reads quantized_model_2.xml, not this
# file. Confirm which model is meant to be evaluated there.
ov.save_model(quantized_model_ir, os.path.join("classifier", "quantized_model_focal.xml"))

  return self._level_low.item()
  return self._level_high.item()
Tensor-likes are not close!

Mismatched elements: 2036 / 2048 (99.4%)
Greatest absolute difference: 0.06525492668151855 at index (58, 3) (up to 1e-05 allowed)
Greatest relative difference: 10.966298445983899 at index (17, 1) (up to 1e-05 allowed)
  _check_trace(


time: 8.91 s (started: 2024-05-05 17:48:34 +06:00)


In [31]:
from openvino.runtime.ie_api import CompiledModel
from typing import Union


def validate(model, val_loader, device='cpu'):
    """Evaluate a classifier on ``val_loader`` and print classification metrics.

    Works with an OpenVINO ``CompiledModel`` as well as a regular PyTorch module.

    Args:
        model: an OpenVINO ``CompiledModel``, or a PyTorch module returning
            per-class scores as a (batch, num_classes) tensor.
        val_loader: iterable of ``(images, labels)`` tensor batches.
        device: device a PyTorch model is moved to and run on. It is ignored
            for a ``CompiledModel``, which runs on the device it was compiled for.

    Returns:
        dict with ``'accuracy'`` (percent of correct predictions) and ``'loss'``.
        The loss is always 0.0 because it is not computed, to save time.

    Raises:
        ValueError: if ``val_loader`` yields no samples.
    """
    is_compiled = isinstance(model, CompiledModel)
    if is_compiled:
        # Loop-invariant: the compiled model's first output port.
        output_layer = model.output(0)
    else:
        # Switch to evaluate mode. The previous version referenced an undefined
        # `torch_device` here.
        model.eval()
        model.to(device)

    loss = 0.0  # force validation loss to zero to reduce computation time
    total = 0
    correct = 0
    y_true = []
    y_pred = []

    with torch.no_grad():
        for images, labels_batch in val_loader:
            if is_compiled:
                # OpenVINO consumes host memory; its result is a numpy array.
                output = model(images.to('cpu'))[output_layer]
                output_batch = torch.from_numpy(output)
            else:
                # Previously this branch never produced output_batch.
                output_batch = model(images.to(device)).cpu()
            labels_batch = labels_batch.to('cpu')

            _, predicted = output_batch.max(1)
            total += labels_batch.size(0)
            correct += predicted.eq(labels_batch).sum().item()
            y_pred.extend(predicted.numpy())
            y_true.extend(labels_batch.numpy())

    if total == 0:
        raise ValueError("validate(): val_loader yielded no samples")

    fscore = f1_score(y_true, y_pred, average='weighted')
    precision = precision_score(y_true, y_pred, average='weighted')
    recall = recall_score(y_true, y_pred, average='weighted')
    kappa = cohen_kappa_score(y_true, y_pred)
    report = classification_report(y_true, y_pred)
    matrix = confusion_matrix(y_true, y_pred)
    # Per-class recall (diagonal / row totals). The result is NaN, with a
    # RuntimeWarning, for a class that is predicted but never appears in y_true.
    accuracy_class = matrix.diagonal() / matrix.sum(axis=1)

    print("Precision: ", precision)
    print("Recall: ", recall)
    print("Report: ", report)
    print("F1: ", fscore)
    print("Kappa: ", kappa)
    print("Matrix: ", matrix)
    print("Accuracy Class: ", accuracy_class)

    acc = 100. * correct / total
    logging.info("- Eval metrics, acc:{acc:.4f}, loss: {loss:.4f}".format(acc=acc, loss=loss))
    return {'accuracy': acc, 'loss': loss}

time: 0 ns (started: 2024-05-06 07:05:34 +06:00)


In [27]:
# Compile the in-memory exported IR for the OpenVINO CPU device and evaluate it.
# NOTE(review): the recorded execution order shows this ran before validate() was
# last (re)defined. Re-run top-to-bottom to confirm the output.
core = ov.Core()
int8_compiled_model = core.compile_model(quantized_model_ir, 'CPU')
validate(int8_compiled_model, val_dataloader )

Precision:  0.9053496938631382
Recall:  0.9108255877528704
Report:                precision    recall  f1-score   support

           0       0.95      0.96      0.96      4465
           1       0.88      0.84      0.86      1945
           2       0.82      0.39      0.53        93
           3       0.86      0.89      0.88      2384
           4       0.64      0.83      0.72       170
           5       0.93      0.97      0.95      7101
           6       0.00      0.00      0.00         3
           7       0.90      0.47      0.62        55
           8       0.66      0.41      0.51       179
           9       0.00      0.00      0.00         6
          10       0.17      0.03      0.05        63
          11       0.00      0.00      0.00        46
          12       0.91      0.80      0.86      1780

    accuracy                           0.91     18290
   macro avg       0.59      0.51      0.53     18290
weighted avg       0.91      0.91      0.91     18290

F1:  0.9062

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


{'accuracy': 91.08255877528704, 'loss': 0.0}

time: 13.1 s (started: 2024-05-05 17:48:43 +06:00)


In [29]:
# Load a previously saved INT8 IR from disk and compile it for CPU inference.
# NOTE(review): this reads quantized_model_2.xml, while the export cell above saves
# quantized_model_focal.xml. Confirm which model is meant to be evaluated here.
int8_ir_path = os.path.join("classifier", "quantized_model_2.xml")
# The message previously said "Save", but this cell only loads the model.
print(f"Load INT8 model: {int8_ir_path}")

core = ov.Core()

model_q_ir = core.read_model(model=int8_ir_path)
model_q = core.compile_model(model=model_q_ir, device_name='CPU')

[2/7] Save INT8 model: .\classifier\quantized_model_2.xml
time: 297 ms (started: 2024-05-06 07:05:18 +06:00)


In [32]:
# Evaluate the INT8 IR loaded from disk in the cell above.
validate(model_q, val_dataloader )

Precision:  0.8763404612825163
Recall:  0.8769819573537452
Report:                precision    recall  f1-score   support

           0       0.95      0.87      0.91      4465
           1       0.86      0.81      0.83      1945
           2       0.36      0.37      0.36        93
           3       0.74      0.91      0.82      2384
           4       0.65      0.72      0.68       170
           5       0.92      0.96      0.94      7101
           6       0.00      0.00      0.00         3
           7       0.73      0.40      0.52        55
           8       0.56      0.36      0.44       179
           9       0.00      0.00      0.00         6
          10       0.27      0.11      0.16        63
          11       0.00      0.00      0.00        46
          12       0.85      0.77      0.80      1780

    accuracy                           0.88     18290
   macro avg       0.53      0.48      0.50     18290
weighted avg       0.88      0.88      0.87     18290

F1:  0.8743

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


{'accuracy': 87.69819573537453, 'loss': 0.0}

time: 12.4 s (started: 2024-05-06 07:05:39 +06:00)


In [None]:
# class QuantizedResNet18(torch.nn.Module):
#     def __init__(self, model_fp32):
#         super(QuantizedResNet18, self).__init__()
  
#         self.quant = torch.quantization.QuantStub()
#         # FP32 model 
#         self.model_fp32 = model_fp32
#         self.dequant = torch.quantization.DeQuantStub()


#     def forward(self, x):
#         x = self.quant(x)
#         x = self.model_fp32(x)
#         x = self.dequant(x)
#         return x

In [None]:
# model_q = QuantizedResNet18(model)

In [None]:
# model_q.eval()
# model_q.qconfig = torch.ao.quantization.default_qconfig
# model_q = torch.ao.quantization.prepare(model_q)
# model_q

In [None]:
# (Empty scratch cell. The accidental run of zeros ending in a dangling '+' raised
# a SyntaxError and broke Restart & Run All.)