In [1]:
#imports
import numpy as np
from skimage import io
from skimage.transform import resize
import matplotlib.pyplot as plt
import random
import matplotlib.patches as patches
import os

import torch
import torchvision
from torchvision import ops
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torch.nn.utils.rnn import pad_sequence
#imports
from tqdm import tqdm
import pickle

In [3]:
#To use in Colab
!git clone https://github.com/gabriellecaillaud/APS360_Traffic_Sign_Recognition.git

Cloning into 'APS360_Traffic_Sign_Recognition'...
remote: Enumerating objects: 15398, done.[K
remote: Counting objects: 100% (74/74), done.[K
remote: Compressing objects: 100% (56/56), done.[K
remote: Total 15398 (delta 30), reused 51 (delta 16), pack-reused 15324[K
Receiving objects: 100% (15398/15398), 628.47 MiB | 40.61 MiB/s, done.
Resolving deltas: 100% (71/71), done.
Updating files: 100% (5347/5347), done.


In [2]:

#path csv with labels
csv_path = "/content/APS360_Traffic_Sign_Recognition/dataset_traffic_signs.csv"

In [3]:
class ObjectDetectionDataset(Dataset):
    '''
    A Pytorch Dataset class to load the images and their corresponding annotations.
    
    Returns
    ------------
    images: torch.Tensor of size (B, C, H, W)
    gt bboxes: torch.Tensor of size (B, max_objects, 4)
    gt classes: torch.Tensor of size (B, max_objects)
    '''
    def __init__(self, csv_path, img_size, name2idx):
        self.annotation_path = csv_path
        self.img_size = img_size
        self.name2idx = name2idx
        
        self.img_data_all, self.gt_bboxes_all, self.gt_classes_all = self.get_data()
        
    def __len__(self):
        return self.img_data_all.size(dim=0)
    
    def __getitem__(self, idx):
        return self.img_data_all[idx], self.gt_bboxes_all[idx], self.gt_classes_all[idx]
        
    def get_data(self):
        img_data_all = []
        gt_idxs_all = []
        
        gt_boxes_all, gt_classes_all, img_paths = parse_annotation(self.annotation_path, self.img_size)
        
        for i, img_path in tqdm(enumerate(img_paths), total=len(img_paths)):
      
            # skip if the image path is not valid
            if (not img_path) or (not os.path.exists(img_path)):
                continue
            
            # read and resize image
            
            img = io.imread(img_path)
            img = resize(img, self.img_size)
            
            # convert image to torch tensor and reshape it so channels come first
            img_tensor = torch.from_numpy(img).permute(2, 0, 1)
            
            # encode class names as integers
            gt_classes = gt_classes_all[i]
            gt_idx = torch.Tensor([self.name2idx[name] for name in gt_classes])
            
            img_data_all.append(img_tensor)
            gt_idxs_all.append(gt_idx)
        
        # pad bounding boxes and classes so they are of the same size

        if len(gt_boxes_all)!=0 and len(gt_idxs_all)!=0 :
          
          gt_bboxes_pad = pad_sequence(gt_boxes_all, batch_first=True, padding_value=-1)
          gt_classes_pad = pad_sequence(gt_idxs_all, batch_first=True, padding_value=-1)
        
        # stack all images
        img_data_stacked = torch.stack(img_data_all)[:, :3, :, :]
        
        return img_data_stacked.to(dtype=torch.float32), gt_bboxes_pad, gt_classes_pad

    def __getstate__(self):
        state = self.__dict__.copy()
        state['img_data_all'] = pickle.dumps(state['img_data_all'])
        state['gt_bboxes_all'] = pickle.dumps(state['gt_bboxes_all'])
        state['gt_classes_all'] = pickle.dumps(state['gt_classes_all'])
        return state

    def __setstate__(self, state):
        state['img_data_all'] = pickle.loads(state['img_data_all'])
        state['gt_bboxes_all'] = pickle.loads(state['gt_bboxes_all'])
        state['gt_classes_all'] = pickle.loads(state['gt_classes_all'])
        self.__dict__.update(state)

In [4]:
img_width = 128
img_height = 128
csv_path = "/content/APS360_Traffic_Sign_Recognition/dataset_traffic_signs.csv"
image_dir = os.path.join("data", "images")
name2idx = {'pad': -1, '30kmh': 0,'60kmh':1, '100kmh' : 2, 'yield': 3, 'keepRight' :4, 'NoEntry':5, 'NoLeft': 6, 'Stop':7, 'noRight':8, 'ChildrenCrossing' :9 }
idx2name = {v:k for k, v in name2idx.items()}

In [5]:
#Load the custom modules on colab based on the files on git
import sys
 #To use on colab
sys.path.append('/content/APS360_Traffic_Sign_Recognition')

# might be neccessary to change to from RCNN_model.utils import * in model
#from RCNN_model.model import *
from RCNN_model.utils import *

In [22]:
import torch
import torchvision
from torchvision import ops
import torch.nn.functional as F
import torch.optim as optim
import torch.nn as nn

#from utils import *

class FeatureExtractor(nn.Module):
    def __init__(self):
        super().__init__()
        model = torchvision.models.resnet50(pretrained=True)
        req_layers = list(model.children())[:8]
        self.backbone = nn.Sequential(*req_layers)
        for param in self.backbone.named_parameters():
            param[1].requires_grad = True
        
    def forward(self, img_data):
        return self.backbone(img_data)
    
    def freeze_layers(self, num_layers):
        for i, child in enumerate(self.backbone.children()):
            if i < num_layers:
                for param in child.parameters():
                    param.requires_grad = False
            else:
                break
        

class ProposalModule(nn.Module):
    def __init__(self, in_features, hidden_dim=128, n_anchors=9, p_dropout=0.3):
        super().__init__()
        self.n_anchors = n_anchors
        self.conv1 = nn.Conv2d(in_features, hidden_dim, kernel_size=3, padding=1)
        self.dropout = nn.Dropout(p_dropout)
        self.conf_head = nn.Conv2d(hidden_dim, n_anchors, kernel_size=1)
        self.reg_head = nn.Conv2d(hidden_dim, n_anchors * 4, kernel_size=1)
        
    def forward(self, feature_map, pos_anc_ind=None, neg_anc_ind=None, pos_anc_coords=None):
        # determine mode
        if pos_anc_ind is None or neg_anc_ind is None or pos_anc_coords is None:
            mode = 'eval'
        else:
            mode = 'train'
            
        out = self.conv1(feature_map)
        out = F.relu(self.dropout(out))
        
        reg_offsets_pred = self.reg_head(out) # (B, A*4, hmap, wmap)
        conf_scores_pred = self.conf_head(out) # (B, A, hmap, wmap)
        
        if mode == 'train': 
            # get conf scores 
            conf_scores_pos = conf_scores_pred.flatten()[pos_anc_ind]
            conf_scores_neg = conf_scores_pred.flatten()[neg_anc_ind]
            # get offsets for +ve anchors
            offsets_pos = reg_offsets_pred.contiguous().view(-1, 4)[pos_anc_ind]
            # generate proposals using offsets
            proposals = generate_proposals(pos_anc_coords, offsets_pos)
            
            return conf_scores_pos, conf_scores_neg, offsets_pos, proposals
            
        elif mode == 'eval':
            return conf_scores_pred, reg_offsets_pred
           
class RegionProposalNetwork(nn.Module):
    def __init__(self, img_size, out_size, out_channels):
        super().__init__()
        
        self.img_height, self.img_width = img_size
        self.out_h, self.out_w = out_size
        
        # downsampling scale factor 
        self.width_scale_factor = self.img_width // self.out_w
        self.height_scale_factor = self.img_height // self.out_h 
        
        # scales and ratios for anchor boxes
        self.anc_scales = [2, 4, 6]
        self.anc_ratios = [0.5, 1, 1.5]
        self.n_anc_boxes = len(self.anc_scales) * len(self.anc_ratios)
        
        # IoU thresholds for +ve and -ve anchors
        self.pos_thresh = 0.7
        self.neg_thresh = 0.3
        
        # weights for loss
        self.w_conf = 1
        self.w_reg = 5
        
        self.feature_extractor = FeatureExtractor()
        #freezing the first 6 layers
        self.feature_extractor.freeze_layers(8)
        self.proposal_module = ProposalModule(out_channels, n_anchors=self.n_anc_boxes)
        
    def forward(self, images, gt_bboxes, gt_classes):
        batch_size = images.size(dim=0)
        feature_map = self.feature_extractor(images)
        
        # generate anchors
        anc_pts_x, anc_pts_y = gen_anc_centers(out_size=(self.out_h, self.out_w))
        anc_base = gen_anc_base(anc_pts_x, anc_pts_y, self.anc_scales, self.anc_ratios, (self.out_h, self.out_w))
        anc_boxes_all = anc_base.repeat(batch_size, 1, 1, 1, 1)
        
        # get positive and negative anchors amongst other things
        gt_bboxes_proj = project_bboxes(gt_bboxes, self.width_scale_factor, self.height_scale_factor, mode='p2a')
        
        positive_anc_ind, negative_anc_ind, GT_conf_scores, \
        GT_offsets, GT_class_pos, positive_anc_coords, \
        negative_anc_coords, positive_anc_ind_sep = get_req_anchors(anc_boxes_all, gt_bboxes_proj, gt_classes)
        
        # pass through the proposal module
        conf_scores_pos, conf_scores_neg, offsets_pos, proposals = self.proposal_module(feature_map, positive_anc_ind, \
                                                                                        negative_anc_ind, positive_anc_coords)
        
        cls_loss = calc_cls_loss(conf_scores_pos, conf_scores_neg, batch_size)
        reg_loss = calc_bbox_reg_loss(GT_offsets, offsets_pos, batch_size)
        
        total_rpn_loss = self.w_conf * cls_loss + self.w_reg * reg_loss
        
        return total_rpn_loss, feature_map, proposals, positive_anc_ind_sep, GT_class_pos

    def inference(self, images, conf_thresh=0.5, nms_thresh=0.7):
        with torch.no_grad():
            batch_size = images.size(dim=0)
            feature_map = self.feature_extractor(images)

            # generate anchors
            anc_pts_x, anc_pts_y = gen_anc_centers(out_size=(self.out_h, self.out_w))
            anc_base = gen_anc_base(anc_pts_x, anc_pts_y, self.anc_scales, self.anc_ratios, (self.out_h, self.out_w))
            anc_boxes_all = anc_base.repeat(batch_size, 1, 1, 1, 1)
            anc_boxes_flat = anc_boxes_all.reshape(batch_size, -1, 4)

            # get conf scores and offsets
            conf_scores_pred, offsets_pred = self.proposal_module(feature_map)
            conf_scores_pred = conf_scores_pred.reshape(batch_size, -1)
            offsets_pred = offsets_pred.reshape(batch_size, -1, 4)

            # filter out proposals based on conf threshold and nms threshold for each image
            proposals_final = []
            conf_scores_final = []
            for i in range(batch_size):
                conf_scores = torch.sigmoid(conf_scores_pred[i])
                offsets = offsets_pred[i]
                anc_boxes = anc_boxes_flat[i]
                proposals = generate_proposals(anc_boxes, offsets)
                # filter based on confidence threshold
                conf_idx = torch.where(conf_scores >= conf_thresh)[0]
                conf_scores_pos = conf_scores[conf_idx]
                proposals_pos = proposals[conf_idx]
                # filter based on nms threshold
                nms_idx = ops.nms(proposals_pos, conf_scores_pos, nms_thresh)
                conf_scores_pos = conf_scores_pos[nms_idx]
                proposals_pos = proposals_pos[nms_idx]

                proposals_final.append(proposals_pos)
                conf_scores_final.append(conf_scores_pos)
            
        return proposals_final, conf_scores_final, feature_map
    
class ClassificationModule(nn.Module):
    def __init__(self, out_channels, n_classes, roi_size, hidden_dim=64, p_dropout=0.3):
        super().__init__()        
        self.roi_size = roi_size
        # hidden network
        self.avg_pool = nn.AvgPool2d(self.roi_size)
        self.fc = nn.Linear(out_channels, hidden_dim)
        self.dropout = nn.Dropout(p_dropout)
        
        # define classification head
        self.cls_head = nn.Linear(hidden_dim, n_classes)
        
    def forward(self, feature_map, proposals_list, gt_classes=None):
        
        if gt_classes is None:
            mode = 'eval'
        else:
            mode = 'train'
        
        # apply roi pooling on proposals followed by avg pooling
        roi_out = ops.roi_pool(feature_map, proposals_list, self.roi_size)
        roi_out = self.avg_pool(roi_out)
        
        # flatten the output
        roi_out = roi_out.squeeze(-1).squeeze(-1)
        
        # pass the output through the hidden network
        out = self.fc(roi_out)
        out = F.relu(self.dropout(out))
        
        # get the classification scores
        cls_scores = self.cls_head(out)
        
        if mode == 'eval':
            return cls_scores
        
        # compute cross entropy loss
        cls_loss = F.cross_entropy(cls_scores, gt_classes.long())
        
        return cls_loss
    
class TwoStageDetector(nn.Module):
    def __init__(self, img_size, out_size, out_channels, n_classes, roi_size):
        super().__init__() 
        self.rpn = RegionProposalNetwork(img_size, out_size, out_channels)
        self.classifier = ClassificationModule(out_channels, n_classes, roi_size)
        
    def forward(self, images, gt_bboxes, gt_classes):
        total_rpn_loss, feature_map, proposals, \
        positive_anc_ind_sep, GT_class_pos = self.rpn(images, gt_bboxes, gt_classes)
        
        # get separate proposals for each sample
        pos_proposals_list = []
        batch_size = images.size(dim=0)
        for idx in range(batch_size):
            proposal_idxs = torch.where(positive_anc_ind_sep == idx)[0]
            proposals_sep = proposals[proposal_idxs].detach().clone()
            pos_proposals_list.append(proposals_sep)
        
        cls_loss = self.classifier(feature_map, pos_proposals_list, GT_class_pos)
        total_loss = cls_loss + total_rpn_loss
        
        return total_loss
    
    def inference(self, images, conf_thresh=0.5, nms_thresh=0.7):
        batch_size = images.size(dim=0)
        proposals_final, conf_scores_final, feature_map = self.rpn.inference(images, conf_thresh, nms_thresh)
        cls_scores = self.classifier(feature_map, proposals_final)
        
        # convert scores into probability
        cls_probs = F.softmax(cls_scores, dim=-1)
        # get classes with highest probability
        classes_all = torch.argmax(cls_probs, dim=-1)
        
        classes_final = []
        # slice classes to map to their corresponding image
        c = 0
        for i in range(batch_size):
            n_proposals = len(proposals_final[i]) # get the number of proposals for each image
            classes_final.append(classes_all[c: c+n_proposals])
            c += n_proposals
            
        return proposals_final, conf_scores_final, classes_final

# ------------------- Loss Utils ----------------------
def calc_cls_loss(conf_scores_pos, conf_scores_neg, batch_size):
    """
    Calculate the classification loss (binary cross-entropy loss) for the RPN.

    Args:
        conf_scores_pos (Tensor): the predicted class scores for the positive anchors
        conf_scores_neg (Tensor): the predicted class scores for the negative anchors
        batch_size (int): the number of samples in the batch

    Returns:
        Tensor: the binary cross-entropy loss
    """

    # concatenate the positive and negative scores
    inputs = torch.cat((conf_scores_pos, conf_scores_neg), dim=0)

    # create the target tensor
    target = torch.cat((torch.ones(conf_scores_pos.shape[0], dtype=torch.float), 
                        torch.zeros(conf_scores_neg.shape[0], dtype=torch.float)), dim=0)

    # convert the target tensor to float
    target = target.float()

    # calculate binary cross entropy loss
    loss_fn = nn.CrossEntropyLoss(weight=None, reduction='mean')
    loss = loss_fn(inputs, target)
    clip_value = 1e8
    if clip_value is not None:
        torch.nn.utils.clip_grad_norm_(loss_fn.parameters(), clip_value) 

    return loss


#def calc_cls_loss(conf_scores_pos, conf_scores_neg, batch_size):
    #target_pos = torch.ones_like(conf_scores_pos)
    #target_neg = torch.zeros_like(conf_scores_neg)
    
    #target = torch.cat((target_pos, target_neg))
    #inputs = torch.cat((conf_scores_pos, conf_scores_neg))
     
    #loss = F.binary_cross_entropy_with_logits(inputs, target, reduction='sum') * 1. / batch_size
    #loss_fn = nn.CrossEntropyLoss(inputs, target, reduction = 'sum')*1./batch_size
    #return loss_fn

def calc_bbox_reg_loss(gt_offsets, reg_offsets_pos, batch_size):
    assert gt_offsets.size() == reg_offsets_pos.size()
    loss = F.smooth_l1_loss(reg_offsets_pos, gt_offsets, reduction='sum') * 1. / batch_size
    return loss



In [91]:
#function taken from tutorial 3b
def get_accuracy(model, data):
    
    correct = 0
    total = 0
    model.eval() #*********#
    dataloader = DataLoader(data, batch_size=64)
    for img_batch, gt_bboxes_batch, gt_classes_batch in tqdm(dataloader):
        proposals_final, conf_scores_final, classes_final = model.inference(img_batch, conf_thresh=0.90, nms_thresh=0.05)
        
        for i in range(len(gt_classes_batch)):
                y_true = gt_classes_batch[i]
                 # check if the tensor is empty
                if conf_scores_final[1].numel() == 0:
                  continue
                idx = torch.argmax(conf_scores_final[1])
                current_class_ = classes_final[1][idx]
                current_class = current_class_.unsqueeze(0)
                if current_class.item()== y_true.item():
                  correct += 1
                total += 1
    return correct/total


In [84]:
get_accuracy(detector, val_dataset)

100%|██████████| 10/10 [00:38<00:00,  3.88s/it]


0.15141955835962145

In [11]:
datset = ObjectDetectionDataset(csv_path, (img_height, img_width), name2idx)

100%|██████████| 3688/3688 [00:30<00:00, 121.58it/s]


In [12]:
od_dataloader = DataLoader(datset, batch_size=32)

In [8]:
img_size = (img_height, img_width)
out_size = (4,4) ## see other d
n_classes = len(name2idx) - 1 # exclude pad idx
roi_size = (2, 2)
out_c = 2048

In [13]:
detector = TwoStageDetector(img_size, out_size, out_c, n_classes, roi_size)



In [14]:
learning_rate = 1e-2
n_epochs = 2
loss_list = training_loop(detector, learning_rate, od_dataloader, n_epochs)

  0%|          | 0/2 [00:00<?, ?it/s]

iteration 0
loss tensor(10.8066, grad_fn=<AddBackward0>)
iteration 0
loss tensor(561.8773, grad_fn=<AddBackward0>)
iteration 0
loss tensor(2214.8215, grad_fn=<AddBackward0>)
iteration 0
loss tensor(2078.7554, grad_fn=<AddBackward0>)
iteration 0
loss tensor(691.8011, grad_fn=<AddBackward0>)
iteration 0
loss tensor(309.1836, grad_fn=<AddBackward0>)
iteration 0
loss tensor(47.9379, grad_fn=<AddBackward0>)
iteration 0
loss tensor(145.9836, grad_fn=<AddBackward0>)
iteration 0
loss tensor(7.6316, grad_fn=<AddBackward0>)
iteration 0
loss tensor(168.1086, grad_fn=<AddBackward0>)
iteration 0
loss tensor(18.2922, grad_fn=<AddBackward0>)
iteration 0
loss tensor(63.2084, grad_fn=<AddBackward0>)
iteration 0
loss tensor(10.7894, grad_fn=<AddBackward0>)
iteration 0
loss tensor(62.7178, grad_fn=<AddBackward0>)
iteration 0
loss tensor(nan, grad_fn=<AddBackward0>)
Skipping NaN loss
iteration 0
loss tensor(287.0280, grad_fn=<AddBackward0>)
iteration 0
loss tensor(nan, grad_fn=<AddBackward0>)
Skipping NaN

 50%|█████     | 1/2 [04:03<04:03, 243.58s/it]

loss tensor(nan, grad_fn=<AddBackward0>)
Skipping NaN loss
iteration 1
loss tensor(34.8989, grad_fn=<AddBackward0>)
iteration 1
loss tensor(225.6290, grad_fn=<AddBackward0>)
iteration 1
loss tensor(142.9758, grad_fn=<AddBackward0>)
iteration 1
loss tensor(409.4698, grad_fn=<AddBackward0>)
iteration 1
loss tensor(119.5104, grad_fn=<AddBackward0>)
iteration 1
loss tensor(58.1959, grad_fn=<AddBackward0>)
iteration 1
loss tensor(10.0574, grad_fn=<AddBackward0>)
iteration 1
loss tensor(5.5498, grad_fn=<AddBackward0>)
iteration 1
loss tensor(2.8526, grad_fn=<AddBackward0>)
iteration 1
loss tensor(19.9750, grad_fn=<AddBackward0>)
iteration 1
loss tensor(2.2267, grad_fn=<AddBackward0>)
iteration 1
loss tensor(71.7837, grad_fn=<AddBackward0>)
iteration 1
loss tensor(0.1548, grad_fn=<AddBackward0>)
iteration 1
loss tensor(6.7257, grad_fn=<AddBackward0>)
iteration 1
loss tensor(nan, grad_fn=<AddBackward0>)
Skipping NaN loss
iteration 1
loss tensor(27.7759, grad_fn=<AddBackward0>)
iteration 1
loss

100%|██████████| 2/2 [08:16<00:00, 248.13s/it]

loss tensor(nan, grad_fn=<AddBackward0>)
Skipping NaN loss





In [15]:
loss_list

[12142.567273378372, 1864.0959092974663]

In [16]:
torch.save(detector.state_dict(), 'detector.pt')

In [87]:
def training_loop(model, learning_rate, train_dataloader, val_dataset, n_epochs):
    
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    loss_fn = nn.CrossEntropyLoss(weight=None, reduction='mean')
    
    model.train()
    train_loss_list = []
    val_loss_list = []
    val_acc_list = []
    
    for i in tqdm(range(n_epochs)):
        # Train
        train_total_loss = 0
        for img_batch, gt_bboxes_batch, gt_classes_batch in tqdm(train_dataloader):
            # Forward pass
            train_loss = model(img_batch, gt_bboxes_batch, gt_classes_batch)
            
            if torch.isnan(train_loss):
                print("Skipping NaN loss")
                continue
            
            # Backpropagation
            optimizer.zero_grad()
            train_loss.backward()
            optimizer.step()
            
            train_total_loss += train_loss.item()
        
        train_loss_list.append(train_total_loss)
        
        # Validation
        model.eval()
        val_total_loss = 0
        val_total_correct = 0
        val_total_samples = 0
        val_dataloader = DataLoader(val_dataset, batch_size = 64)
        with torch.no_grad():
            for img_batch, gt_bboxes_batch, gt_classes_batch in val_dataloader:
                # Forward pass
                val_loss = model(img_batch, gt_bboxes_batch, gt_classes_batch)
                val_total_loss += val_loss.item()
               
        val_loss_list.append(val_total_loss)
        val_acc = get_accuracy(model,val_dataset)
        val_acc_list.append(val_acc)
        
        model.train()
        
    return train_loss_list, val_loss_list, val_acc_list


In [86]:
val_csv_path = "/content/APS360_Traffic_Sign_Recognition/val.csv"

In [20]:
val_dataset = ObjectDetectionDataset(val_csv_path, (img_height, img_width), name2idx)

100%|██████████| 729/729 [00:04<00:00, 149.64it/s]


In [21]:
val_dataloader = DataLoader(val_dataset, batch_size = 32)

In [88]:
detector = TwoStageDetector(img_size, out_size, out_c, n_classes, roi_size)



In [90]:
lr = 0.05
n_epochs = 2
train_loss_list, val_loss_list, val_acc_list = training_loop(detector, lr, od_dataloader, val_dataset, n_epochs)

  0%|          | 0/2 [00:00<?, ?it/s]
  0%|          | 0/102 [00:00<?, ?it/s][A
  1%|          | 1/102 [00:02<03:57,  2.35s/it][A
  2%|▏         | 2/102 [00:04<03:47,  2.28s/it][A
  3%|▎         | 3/102 [00:07<04:23,  2.66s/it][A
  4%|▍         | 4/102 [00:10<04:07,  2.52s/it][A
  5%|▍         | 5/102 [00:12<03:53,  2.40s/it][A
  6%|▌         | 6/102 [00:14<03:44,  2.34s/it][A
  7%|▋         | 7/102 [00:16<03:38,  2.30s/it][A
  8%|▊         | 8/102 [00:19<03:48,  2.43s/it][A
  9%|▉         | 9/102 [00:21<03:51,  2.49s/it][A
 10%|▉         | 10/102 [00:24<03:40,  2.40s/it][A
 11%|█         | 11/102 [00:26<03:33,  2.35s/it][A
 12%|█▏        | 12/102 [00:28<03:27,  2.31s/it][A
 13%|█▎        | 13/102 [00:30<03:25,  2.31s/it][A
 14%|█▎        | 14/102 [00:34<03:43,  2.54s/it][A
 15%|█▍        | 15/102 [00:36<03:30,  2.42s/it][A

Skipping NaN loss



 16%|█▌        | 16/102 [00:38<03:23,  2.36s/it][A
 17%|█▋        | 17/102 [00:40<03:14,  2.29s/it][A

Skipping NaN loss



 18%|█▊        | 18/102 [00:42<03:10,  2.27s/it][A
 19%|█▊        | 19/102 [00:45<03:26,  2.48s/it][A

Skipping NaN loss



 20%|█▉        | 20/102 [00:48<03:20,  2.45s/it][A
 21%|██        | 21/102 [00:50<03:12,  2.38s/it][A
 22%|██▏       | 22/102 [00:52<03:06,  2.34s/it][A
 23%|██▎       | 23/102 [00:54<03:00,  2.28s/it][A

Skipping NaN loss



 24%|██▎       | 24/102 [00:57<03:07,  2.40s/it][A
 25%|██▍       | 25/102 [01:00<03:11,  2.49s/it][A
 25%|██▌       | 26/102 [01:02<03:03,  2.41s/it][A
 26%|██▋       | 27/102 [01:04<02:56,  2.35s/it][A
 27%|██▋       | 28/102 [01:06<02:49,  2.29s/it][A

Skipping NaN loss



 28%|██▊       | 29/102 [01:08<02:45,  2.27s/it][A
 29%|██▉       | 30/102 [01:11<03:02,  2.53s/it][A
 30%|███       | 31/102 [01:14<02:52,  2.43s/it][A
 31%|███▏      | 32/102 [01:16<02:46,  2.38s/it][A
 32%|███▏      | 33/102 [01:18<02:40,  2.33s/it][A
 33%|███▎      | 34/102 [01:20<02:35,  2.28s/it][A

Skipping NaN loss



 34%|███▍      | 35/102 [01:24<02:53,  2.59s/it][A
 35%|███▌      | 36/102 [01:26<02:44,  2.49s/it][A
 36%|███▋      | 37/102 [01:28<02:36,  2.41s/it][A
 37%|███▋      | 38/102 [01:30<02:30,  2.36s/it][A
 38%|███▊      | 39/102 [01:33<02:24,  2.30s/it][A

Skipping NaN loss



 39%|███▉      | 40/102 [01:35<02:31,  2.45s/it][A
 40%|████      | 41/102 [01:38<02:32,  2.50s/it][A
 41%|████      | 42/102 [01:40<02:25,  2.42s/it][A
 42%|████▏     | 43/102 [01:42<02:17,  2.34s/it][A

Skipping NaN loss



 43%|████▎     | 44/102 [01:44<02:12,  2.28s/it][A

Skipping NaN loss



 44%|████▍     | 45/102 [01:47<02:10,  2.30s/it][A
 45%|████▌     | 46/102 [01:50<02:20,  2.51s/it][A
 46%|████▌     | 47/102 [01:52<02:12,  2.40s/it][A

Skipping NaN loss



 47%|████▋     | 48/102 [01:54<02:05,  2.33s/it][A

Skipping NaN loss



 48%|████▊     | 49/102 [01:56<02:00,  2.27s/it][A

Skipping NaN loss



 49%|████▉     | 50/102 [01:59<01:58,  2.27s/it][A
 50%|█████     | 51/102 [02:02<02:07,  2.51s/it][A

Skipping NaN loss



 51%|█████     | 52/102 [02:04<02:01,  2.44s/it][A

Skipping NaN loss



 52%|█████▏    | 53/102 [02:06<01:55,  2.37s/it][A

Skipping NaN loss



 53%|█████▎    | 54/102 [02:08<01:50,  2.31s/it][A

Skipping NaN loss



 54%|█████▍    | 55/102 [02:10<01:46,  2.26s/it][A

Skipping NaN loss



 55%|█████▍    | 56/102 [02:13<01:46,  2.32s/it][A

Skipping NaN loss



 56%|█████▌    | 57/102 [02:16<01:49,  2.44s/it][A

Skipping NaN loss



 57%|█████▋    | 58/102 [02:18<01:43,  2.35s/it][A

Skipping NaN loss



 58%|█████▊    | 59/102 [02:20<01:38,  2.30s/it][A

Skipping NaN loss



 59%|█████▉    | 60/102 [02:22<01:35,  2.27s/it][A
 60%|█████▉    | 61/102 [02:24<01:32,  2.26s/it][A
 61%|██████    | 62/102 [02:28<01:41,  2.54s/it][A
 62%|██████▏   | 63/102 [02:30<01:35,  2.45s/it][A
 63%|██████▎   | 64/102 [02:32<01:30,  2.38s/it][A
 64%|██████▎   | 65/102 [02:34<01:26,  2.34s/it][A
 65%|██████▍   | 66/102 [02:36<01:22,  2.30s/it][A
 66%|██████▌   | 67/102 [02:39<01:28,  2.53s/it][A
 67%|██████▋   | 68/102 [02:42<01:24,  2.48s/it][A
 68%|██████▊   | 69/102 [02:44<01:19,  2.41s/it][A
 69%|██████▊   | 70/102 [02:46<01:15,  2.35s/it][A
 70%|██████▉   | 71/102 [02:49<01:18,  2.52s/it][A
 71%|███████   | 72/102 [02:52<01:20,  2.69s/it][A
 72%|███████▏  | 73/102 [02:55<01:15,  2.59s/it][A
 73%|███████▎  | 74/102 [02:57<01:09,  2.49s/it][A
 74%|███████▎  | 75/102 [02:59<01:05,  2.41s/it][A
 75%|███████▍  | 76/102 [03:01<01:01,  2.36s/it][A
 75%|███████▌  | 77/102 [03:04<01:01,  2.44s/it][A
 76%|███████▋  | 78/102 [03:07<01:00,  2.52s/it][A
 77%|██████

Skipping NaN loss



 92%|█████████▏| 94/102 [03:45<00:19,  2.47s/it][A
 93%|█████████▎| 95/102 [03:47<00:16,  2.39s/it][A
 94%|█████████▍| 96/102 [03:49<00:13,  2.32s/it][A

Skipping NaN loss



 95%|█████████▌| 97/102 [03:52<00:11,  2.26s/it][A

Skipping NaN loss



 96%|█████████▌| 98/102 [03:54<00:09,  2.27s/it][A
 97%|█████████▋| 99/102 [03:57<00:07,  2.49s/it][A

Skipping NaN loss



 98%|█████████▊| 100/102 [03:59<00:04,  2.41s/it][A
 99%|█████████▉| 101/102 [04:01<00:02,  2.33s/it][A

Skipping NaN loss



100%|██████████| 102/102 [04:02<00:00,  2.38s/it]

Skipping NaN loss




  0%|          | 0/10 [00:00<?, ?it/s][A
 10%|█         | 1/10 [00:03<00:32,  3.66s/it][A
 20%|██        | 2/10 [00:08<00:34,  4.26s/it][A
 30%|███       | 3/10 [00:12<00:27,  3.99s/it][A
 40%|████      | 4/10 [00:15<00:23,  3.86s/it][A
 50%|█████     | 5/10 [00:20<00:20,  4.17s/it][A
 60%|██████    | 6/10 [00:24<00:16,  4.00s/it][A
 70%|███████   | 7/10 [00:27<00:11,  3.91s/it][A
 80%|████████  | 8/10 [00:32<00:08,  4.04s/it][A
 90%|█████████ | 9/10 [00:36<00:04,  4.04s/it][A
100%|██████████| 10/10 [00:39<00:00,  3.95s/it]
 50%|█████     | 1/2 [05:19<05:19, 319.50s/it]
  0%|          | 0/102 [00:00<?, ?it/s][A
  1%|          | 1/102 [00:02<03:46,  2.24s/it][A
  2%|▏         | 2/102 [00:05<04:16,  2.57s/it][A
  3%|▎         | 3/102 [00:07<04:13,  2.56s/it][A
  4%|▍         | 4/102 [00:09<03:57,  2.42s/it][A
  5%|▍         | 5/102 [00:12<03:47,  2.34s/it][A
  6%|▌         | 6/102 [00:14<03:41,  2.31s/it][A
  7%|▋         | 7/102 [00:16<03:40,  2.32s/it][A
  8%|▊     

Skipping NaN loss



 16%|█▌        | 16/102 [00:38<03:20,  2.34s/it][A
 17%|█▋        | 17/102 [00:40<03:13,  2.27s/it][A

Skipping NaN loss



 18%|█▊        | 18/102 [00:43<03:28,  2.48s/it][A
 19%|█▊        | 19/102 [00:45<03:23,  2.46s/it][A

Skipping NaN loss



 20%|█▉        | 20/102 [00:48<03:16,  2.39s/it][A
 21%|██        | 21/102 [00:50<03:09,  2.34s/it][A
 22%|██▏       | 22/102 [00:52<03:04,  2.30s/it][A
 23%|██▎       | 23/102 [00:54<03:01,  2.30s/it][A

Skipping NaN loss



 24%|██▎       | 24/102 [00:57<03:15,  2.50s/it][A
 25%|██▍       | 25/102 [01:00<03:05,  2.41s/it][A
 25%|██▌       | 26/102 [01:02<02:59,  2.36s/it][A
 26%|██▋       | 27/102 [01:04<02:54,  2.33s/it][A
 27%|██▋       | 28/102 [01:06<02:48,  2.27s/it][A

Skipping NaN loss



 28%|██▊       | 29/102 [01:09<03:06,  2.55s/it][A
 29%|██▉       | 30/102 [01:12<02:56,  2.45s/it][A
 30%|███       | 31/102 [01:14<02:49,  2.39s/it][A
 31%|███▏      | 32/102 [01:16<02:43,  2.34s/it][A
 32%|███▏      | 33/102 [01:18<02:38,  2.30s/it][A
 33%|███▎      | 34/102 [01:21<02:47,  2.47s/it][A

Skipping NaN loss



 34%|███▍      | 35/102 [01:24<02:45,  2.47s/it][A
 35%|███▌      | 36/102 [01:26<02:37,  2.39s/it][A
 36%|███▋      | 37/102 [01:28<02:32,  2.34s/it][A
 37%|███▋      | 38/102 [01:30<02:27,  2.30s/it][A
 38%|███▊      | 39/102 [01:33<02:26,  2.32s/it][A

Skipping NaN loss



 39%|███▉      | 40/102 [01:36<02:35,  2.51s/it][A
 40%|████      | 41/102 [01:38<02:28,  2.43s/it][A
 41%|████      | 42/102 [01:40<02:22,  2.37s/it][A
 42%|████▏     | 43/102 [01:42<02:15,  2.30s/it][A

Skipping NaN loss



 43%|████▎     | 44/102 [01:44<02:11,  2.26s/it][A

Skipping NaN loss



 44%|████▍     | 45/102 [01:48<02:25,  2.54s/it][A
 45%|████▌     | 46/102 [01:50<02:17,  2.45s/it][A
 46%|████▌     | 47/102 [01:52<02:09,  2.36s/it][A

Skipping NaN loss



 47%|████▋     | 48/102 [01:54<02:05,  2.32s/it][A

Skipping NaN loss



 48%|████▊     | 49/102 [01:56<02:00,  2.27s/it][A

Skipping NaN loss



 49%|████▉     | 50/102 [01:59<02:07,  2.46s/it][A
 50%|█████     | 51/102 [02:02<02:05,  2.45s/it][A

Skipping NaN loss



 51%|█████     | 52/102 [02:04<01:58,  2.36s/it][A

Skipping NaN loss



 52%|█████▏    | 53/102 [02:06<01:52,  2.30s/it][A

Skipping NaN loss



 53%|█████▎    | 54/102 [02:08<01:48,  2.26s/it][A

Skipping NaN loss



 54%|█████▍    | 55/102 [02:10<01:46,  2.26s/it][A

Skipping NaN loss



 55%|█████▍    | 56/102 [02:13<01:54,  2.49s/it][A

Skipping NaN loss



 56%|█████▌    | 57/102 [02:16<01:47,  2.39s/it][A

Skipping NaN loss



 57%|█████▋    | 58/102 [02:18<01:42,  2.32s/it][A

Skipping NaN loss



 58%|█████▊    | 59/102 [02:20<01:38,  2.28s/it][A

Skipping NaN loss



 59%|█████▉    | 60/102 [02:22<01:34,  2.26s/it][A
 60%|█████▉    | 61/102 [02:25<01:42,  2.50s/it][A
 61%|██████    | 62/102 [02:27<01:38,  2.45s/it][A
 62%|██████▏   | 63/102 [02:30<01:32,  2.38s/it][A
 63%|██████▎   | 64/102 [02:32<01:28,  2.33s/it][A
 64%|██████▎   | 65/102 [02:34<01:25,  2.30s/it][A
 65%|██████▍   | 66/102 [02:37<01:26,  2.41s/it][A
 66%|██████▌   | 67/102 [02:40<01:27,  2.50s/it][A
 67%|██████▋   | 68/102 [02:42<01:22,  2.42s/it][A
 68%|██████▊   | 69/102 [02:44<01:18,  2.37s/it][A
 69%|██████▊   | 70/102 [02:46<01:14,  2.32s/it][A
 70%|██████▉   | 71/102 [02:49<01:12,  2.33s/it][A
 71%|███████   | 72/102 [02:52<01:17,  2.57s/it][A
 72%|███████▏  | 73/102 [02:54<01:11,  2.48s/it][A
 73%|███████▎  | 74/102 [02:56<01:07,  2.42s/it][A
 74%|███████▎  | 75/102 [02:58<01:03,  2.36s/it][A
 75%|███████▍  | 76/102 [03:01<01:02,  2.39s/it][A
 75%|███████▌  | 77/102 [03:04<01:05,  2.63s/it][A
 76%|███████▋  | 78/102 [03:06<01:00,  2.51s/it][A
 77%|██████

Skipping NaN loss



 92%|█████████▏| 94/102 [03:45<00:19,  2.46s/it][A
 93%|█████████▎| 95/102 [03:47<00:16,  2.40s/it][A
 94%|█████████▍| 96/102 [03:49<00:14,  2.34s/it][A

Skipping NaN loss



 95%|█████████▌| 97/102 [03:51<00:11,  2.30s/it][A

Skipping NaN loss



 96%|█████████▌| 98/102 [03:55<00:10,  2.56s/it][A
 97%|█████████▋| 99/102 [03:57<00:07,  2.46s/it][A

Skipping NaN loss



 98%|█████████▊| 100/102 [03:59<00:04,  2.39s/it][A
 99%|█████████▉| 101/102 [04:01<00:02,  2.31s/it][A

Skipping NaN loss



100%|██████████| 102/102 [04:02<00:00,  2.38s/it]

Skipping NaN loss




  0%|          | 0/10 [00:03<?, ?it/s]
 50%|█████     | 1/2 [10:03<10:03, 603.44s/it]


IndexError: ignored