In [3]:
# Colab-only setup: mount Google Drive so the project folder is reachable from this runtime.
from google.colab import drive
drive.mount('/content/drive')
import os
# Work from the project folder so the relative paths used below ('data', './model/...') resolve.
# NOTE(review): presumably data_helper.py / helper.py also live here, so this must run before
# those modules are imported -- confirm.
os.chdir('/content/drive/My Drive/DL/Final_submission')
import os
import random

import numpy as np
import pandas as pd
import matplotlib
import matplotlib.pyplot as plt
matplotlib.rcParams['figure.figsize'] = [5, 5]
matplotlib.rcParams['figure.dpi'] = 200
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
#from torch.optim import lr_scheduler
from torch.optim.lr_scheduler import ReduceLROnPlateau
import torchvision
from torchvision import models
from torchvision.models import resnet
from torchvision.models._utils import IntermediateLayerGetter
from torchvision.models.detection import FasterRCNN
from torchvision.models.detection.backbone_utils import BackboneWithFPN
from torchvision.models.detection.rpn import AnchorGenerator
from torchvision.ops import misc as misc_nn_ops

# Project-local modules (resolved relative to the working directory set above).
import helper  # needed for helper.compute_ts_road_map / helper.compute_ats_bounding_boxes
from data_helper import UnlabeledDataset, LabeledDataset
from helper import collate_fn, draw_box

# Pick the first GPU when one is visible, otherwise fall back to the CPU.
cuda = torch.cuda.is_available()
device = torch.device("cuda:0") if cuda else torch.device("cpu")

# Seed every RNG used by the notebook (Python, NumPy, PyTorch) with the same value.
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    _seed_fn(0)

# All the images are saved in image_folder;
# all the labels are saved in the annotation_csv file.
image_folder = 'data'
annotation_csv = 'data/annotation.csv'

print(device)


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
cuda:0


In [4]:
#backbone class from pirl jigsaw
class Representation_Generator(nn.Module):
    """PIRL-jigsaw representation network.

    Produces a 128-d embedding for a batch of images and, optionally, a second
    128-d embedding fused from a sequence of patch batches. The `backbone`
    attribute is what the downstream tasks reuse.
    """

    def __init__(self):
        super().__init__()
        # Only the feature extractor of the FCN segmentation net is kept.
        segmentation_net = torchvision.models.segmentation.fcn_resnet50(pretrained=False)
        self.backbone = segmentation_net.backbone
        # Should we add pyramid network in it as in faster_rcnn ???
        self.pool = nn.AdaptiveAvgPool2d((1, 1))  # collapse the spatial dims to 1x1
        self.head_f = nn.Linear(2048, 128)        # projection for the whole image
        self.head_g = nn.Linear(9 * 2048, 128)    # projection for 9 concatenated patches

    def _embed(self, batch):
        """Backbone 'out' map -> global average pool -> (batch_size, 2048)."""
        pooled = self.pool(self.backbone(batch)['out'])
        return pooled.view(-1, 2048)

    def forward(self, images, patches = None):
        """Return `image_feat`, or `(image_feat, patches_feat)` when `patches` is given.

        `patches` is iterated over; each element is fed through the backbone as
        its own batch and the pooled features are concatenated along dim 1.
        """
        image_feat = self.head_f(self._embed(images))  # (batch_size, 128)
        if patches is None:
            return image_feat

        per_patch = [self._embed(patch) for patch in patches]
        stacked = torch.cat(per_patch, dim=1)          # (batch_size, 2048 * n_patches)
        patches_feat = self.head_g(stacked)            # (batch_size, 128)
        return image_feat, patches_feat


#prepare the pirl backbone for object detection: add FPN and freeze BN
def resnet_fpn_backbone(backbone, freeze_bn = True):
    """Wrap a ResNet-50 style backbone with a feature pyramid network.

    Args:
        backbone: module exposing `layer1` .. `layer4` children; it is modified
            in place (some parameters get `requires_grad=False`).
        freeze_bn: when True, every BatchNorm2d layer is put in eval mode and
            kept there, so later `.train()` calls on a parent module do not
            switch batch-statistics tracking back on.

    Returns:
        The `BackboneWithFPN` wrapper (256 output channels per pyramid level).
    """
    # Same policy as torchvision's Faster R-CNN: only layer2-layer4 stay
    # trainable, everything else (stem and layer1) is frozen.
    trainable_stages = ('layer2', 'layer3', 'layer4')
    for name, parameter in backbone.named_parameters():
        if not any(stage in name for stage in trainable_stages):
            parameter.requires_grad_(False)

    return_layers = {'layer1': '0', 'layer2': '1', 'layer3': '2', 'layer4': '3'}

    in_channels_stage2 = 2048 // 8  # 256: channel count fed to the FPN for layer1
    in_channels_list = [
        in_channels_stage2,
        in_channels_stage2 * 2,
        in_channels_stage2 * 4,
        in_channels_stage2 * 8,
    ]
    out_channels = 256
    bb_fpn = BackboneWithFPN(backbone, return_layers, in_channels_list, out_channels)

    def set_bn_eval(m):
        # Detection trains with tiny batches, whose batch statistics are poor,
        # so BatchNorm2d layers must keep using their stored running stats.
        classname = m.__class__.__name__
        if classname.find('BatchNorm2d') != -1:
            m.eval()
            # nn.Module.train() recurses into every child, so a plain .eval()
            # here would be undone by the first `model.train()` call. Shadow
            # `train` on this instance so the layer stays in eval mode.
            m.train = lambda mode=True, _frozen=m: _frozen

    if freeze_bn:
        bb_fpn.apply(set_bn_eval)
    return bb_fpn


#detection_backbone = resnet_fpn_backbone(torchvision.models.segmentation.fcn_resnet50(pretrained=False).backbone)


In [5]:
# model classes

# road segmentation's classifier
class FCNHead(nn.Sequential):
    """Segmentation head: 3x3 conv -> BN -> ReLU -> dropout -> 1x1 conv.

    Args:
        in_channels: channel count of the (fused) input feature map.
        num_classes: number of output channels, one per class.
    """

    def __init__(self, in_channels, num_classes):
        hidden = in_channels // 4
        stack = [
            nn.Conv2d(in_channels, hidden, 3, padding=1, bias=False),
            nn.BatchNorm2d(hidden),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Conv2d(hidden, num_classes, 1),
        ]
        super().__init__(*stack)

        # Kept as plain attributes for introspection, as before.
        self.inter_channels = hidden
        self.layers = stack

#road segmentation's model
class Road_Layout(nn.Module):
    """Road segmentation model: per-view backbone features, averaged, then classified.

    `backbone` must return a dict with an "out" feature map; `classifier`
    maps the fused feature map to per-class logits.
    """

    def __init__(self, backbone, classifier):
        super().__init__()
        self.backbone = backbone
        self.classifier = classifier

    def forward(self, images, targets=None):
        """Return logits of shape (batch_size, num_classes, 800, 800).

        `images` is indexed along dim 1 for the six camera views; `targets`
        is accepted but not used.
        """
        per_view = [self.backbone(images[:, cam, :])["out"] for cam in range(6)]
        # Views are fused by averaging (not concatenating) their feature maps.
        fused_feature = torch.stack(per_view).mean(dim=0)
        logits = self.classifier(fused_feature)  # (batch_size, num_classes, H, W)
        return F.interpolate(logits, size=(800, 800), mode='bilinear', align_corners=False)


# object detection  - generate top_down layer
class Fusion_Layer(nn.Module):
    """Fuse the six camera views into a 3-channel top-down (bird's-eye) map.

    Each view is passed through `backbone` (which must return a dict with an
    "out" feature map), the six maps are averaged, and five stride-2
    transposed convolutions upsample the result by a factor of 32.

    Args:
        backbone: feature extractor shared by all views.
        feature_channels: channel count of the backbone's "out" map.
        road_model_feat: optional feature map from the road model; when given
            it is concatenated to the fused features along the channel dim, so
            its batch and spatial sizes must match the backbone output.
    """
    def __init__(self, backbone, feature_channels, road_model_feat = None):
        super(Fusion_Layer, self).__init__()
        tot_channels = feature_channels
        self.road_model_feat = road_model_feat
        if road_model_feat is not None:
            tot_channels += road_model_feat.size()[1]

        #for mapping back to 800*800
        self.backbone = backbone
        self.relu =  nn.ReLU()
        self.deconv1 = nn.ConvTranspose2d(tot_channels, tot_channels, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1)
        self.bn1 = nn.BatchNorm2d(tot_channels)
        self.deconv2 = nn.ConvTranspose2d(tot_channels, tot_channels//4, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1)
        self.bn2 = nn.BatchNorm2d(tot_channels//4)
        self.deconv3 = nn.ConvTranspose2d(tot_channels//4, tot_channels//16, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1)
        self.bn3 = nn.BatchNorm2d(tot_channels//16)
        self.deconv4 = nn.ConvTranspose2d(tot_channels//16, tot_channels//16, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1)
        self.bn4 = nn.BatchNorm2d(tot_channels//16)
        self.deconv5 = nn.ConvTranspose2d(tot_channels//16, 3, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1)
        self.bn5 = nn.BatchNorm2d(3)

    def forward(self, images):
        """Return a (batch_size, 3, 32*H, 32*W) map, H x W being the backbone output size."""
        feats = [self.backbone(images[:, view, :])["out"] for view in range(6)]
        fused_feature = torch.mean(torch.stack(feats), dim = 0)
        # fused_feature = torch.cat(feats,dim = 1)
        if self.road_model_feat is not None:
            # Was `toch.cat` (NameError). The stored map is a plain tensor, not
            # a buffer, so move it next to the fused features before joining.
            road_feat = self.road_model_feat.to(fused_feature.device)
            fused_feature = torch.cat([fused_feature, road_feat], dim = 1)

        # Each stage doubles the spatial size: (H, W) -> 2 * (H, W).
        x1 = self.bn1(self.relu(self.deconv1(fused_feature)))
        x2 = self.bn2(self.relu(self.deconv2(x1)))
        x3 = self.bn3(self.relu(self.deconv3(x2)))
        x4 = self.bn4(self.relu(self.deconv4(x3)))
        x5 = self.bn5(self.relu(self.deconv5(x4)))

        return x5

#detection model
class Box_Model(nn.Module):
    """Detection model: build a top-down map with `fuse_layer`, then run `detect_model` on it.

    Coordinates are converted between the dataset convention (metres, four
    corners per box, origin at the map centre) and the detector convention
    (pixels, axis-aligned `[xmin, ymin, xmax, ymax]`) using 10 px per metre
    and a 400 px offset.
    """
    def __init__(self, fuse_layer, detect_model):
        super(Box_Model, self).__init__()
        self.fuse_layer = fuse_layer
        self.detect_model  = detect_model

    def _own_device(self):
        """Device of this module's first parameter, or CPU when it has none."""
        first_param = next(self.parameters(), None)
        return first_param.device if first_param is not None else torch.device("cpu")

    def transform_target(self, targets, target_device=None):
        """Convert dataset targets into the list of dicts the detector expects.

        Args:
            targets: iterable of dicts with 'bounding_box' (N, 2, 4) corner
                coordinates in metres and 'category' (N,) labels.
            target_device: where to place the result; defaults to this
                module's own device instead of relying on a global `device`.

        Returns:
            List of `{"boxes": (N, 4) float32 pixels, "labels": (N,)}` dicts.
        """
        if target_device is None:
            target_device = self._own_device()

        res = []
        for t in targets:
            N = t['category'].size(0)  # number of boxes
            if N == 0:
                bbox = torch.zeros(0, 4)
            else:
                corners = 10 * t['bounding_box'][:N]  # metres -> pixels, (N, 2, 4)
                xs, ys = corners[:, 0, :], corners[:, 1, :]
                # Axis-aligned hull of the four corners, shifted so the map
                # centre sits at pixel (400, 400).
                bbox = torch.stack(
                    [xs.min(dim=1)[0], ys.min(dim=1)[0], xs.max(dim=1)[0], ys.max(dim=1)[0]],
                    dim=1,
                ).to(torch.float32) + 400
            res.append({"boxes": bbox.to(target_device), "labels": t['category'].to(target_device)})
        return res

    def get_output(self, preds):
        """Convert detector predictions back to the dataset convention.

        Returns a dict with "boxes", a tuple of (num_boxes, 2, 4) float32 CPU
        tensors in metres, and "labels", a tuple of (num_boxes,) float32 CPU
        tensors.
        """
        pred_box = []
        pred_label = []
        for p in preds:
            boxes = p['boxes'].detach().to(device='cpu', dtype=torch.float32)  # pixels
            xmin, ymin, xmax, ymax = boxes.unbind(dim=1)
            corners = torch.stack(
                [
                    torch.stack([xmax, xmax, xmin, xmin], dim=1),
                    torch.stack([ymax, ymin, ymax, ymin], dim=1),
                ],
                dim=1,
            )  # (num_boxes, 2, 4)
            pred_box.append((corners - 400) / 10)  # pixels -> metres
            pred_label.append(p['labels'].detach().to(device='cpu', dtype=torch.float32))
        return {"boxes": tuple(pred_box), "labels": tuple(pred_label)}

    def forward(self, images, targets  = None):
        """Return the detector's loss dict in training mode, converted predictions otherwise.

        Raises:
            ValueError: in training mode when `targets` is missing.
        """
        top_down = self.fuse_layer(images)  # (batch_size, 3, H, W)
        image_list = list(top_down)         # the detector takes a list of images
        if self.training:
            if targets is None:
                raise ValueError("In training mode, targets should be passed")
            self.detect_model.train()
            targets = self.transform_target(targets, top_down.device)
            output = self.detect_model(image_list, targets)  # loss dict
        else:
            preds = self.detect_model(image_list)  # list of dicts: 'boxes', 'labels', 'scores'
            output = self.get_output(preds)

        return output




In [6]:
#model loader
##get model to train: model_loader.road_model and model_loader.box_model
class ModelLoader():
    """Builds the road-segmentation and box-detection models and exposes inference helpers.

    Attributes of interest: `road_model` and `box_model` (the two trainable
    models). Both start from the PIRL-jigsaw pretrained backbone, each with
    its own copy so the two tasks can use separate backbones.
    """
    def __init__(self, model_file='models_final.py', test = False):
        """
        Args:
            model_file: not used by this constructor.
            test: when True, load the fine-tuned checkpoints, move both models
                to the GPU if one is available, and switch them to eval mode.
        """
        # 1. create the model object
        # 2. load your state_dict
        # 3. call cuda()
        # road segmentation

        # Read the pretrained checkpoint once; load_state_dict copies the
        # values, so both networks can be initialised from the same dict.
        pirl_weights = torch.load('./model/pirl_jigsaw/rep_net3.pth',map_location=torch.device('cpu'))
        rep_net1 = Representation_Generator()
        rep_net2 = Representation_Generator()
        rep_net1.load_state_dict(pirl_weights)
        rep_net2.load_state_dict(pirl_weights)
        print("Load pretrained backbone successfully!")
        #allow reconstruction of topdown view and object detection to use separate backbone
        self.reconstruct_backbone = rep_net1.backbone
        self.detection_backbone = resnet_fpn_backbone(rep_net2.backbone) #make it conformed with faster_rcnn

        #self.backbone = models.segmentation.fcn_resnet50(num_classes = 2, pretrained=False).backbone
        self.classifier = FCNHead(2048, 2)
        self.road_model = Road_Layout(self.reconstruct_backbone, self.classifier)

        # object detection
        self.num_classes = 10
        # One anchor size per pyramid level, three aspect ratios each.
        self.anchor_generator = AnchorGenerator( ((8,),(16,),(32,), (64,), (128,)),
                                        ((0.5, 1.0, 2.0),) *5)
        self.roi_pooler = torchvision.ops.MultiScaleRoIAlign(featmap_names=['0', '1', '2', '3'], #corresponds to 4 layers in resnet
                                                        output_size=7,
                                                        sampling_ratio=2)

        ## put the pieces together inside a FasterRCNN model, input needs to be (batch_size, channels, H, W)
        self.detect_model = FasterRCNN(self.detection_backbone,
                        num_classes = self.num_classes,
                        rpn_anchor_generator = self.anchor_generator,
                        box_roi_pool = self.roi_pooler)

        #self.reconstruct_bone = torchvision.models.segmentation.fcn_resnet50(pretrained=False).backbone.to(device)
        # Note: the fusion layer shares `reconstruct_backbone` with the road model.
        self.FL = Fusion_Layer(self.reconstruct_backbone, 2048, road_model_feat = None)
        self.box_model = Box_Model(self.FL, self.detect_model)

        #load model state_dict
        if test:
            self.road_model.load_state_dict(torch.load('./model/road_seg_BH/epoch_16_avg_val_loss_0.04860_avg_ts_0.91775_lr_0.0001000000_2.pth',map_location=torch.device('cpu')))
            self.box_model.load_state_dict(torch.load('./model/box_detect/fintune_epoch_2_avg_ats_0.01862_lr_0.0001000000.pth',map_location=torch.device('cpu')))

            # Step 3 of the checklist above: the inference helpers receive CUDA
            # tensors, so the models must live on the GPU too when there is one.
            run_device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
            self.road_model.to(run_device).eval()
            self.box_model.to(run_device).eval()

    def get_bounding_boxes(self, samples):
        """Predict boxes for a batch of samples.

        Args:
            samples: tensor of size [batch_size, 6, 3, H, W].

        Returns:
            Tuple of length batch_size; each element is an [N, 2, 4] tensor
            (N objects), on whatever device `box_model` returns it on. A sample
            with no detections gets a single all-zero box, decided per sample,
            so one empty sample no longer wipes the rest of the batch.
        """
        with torch.no_grad():
            output = self.box_model(samples)

        padded = []
        for boxes in output['boxes']:
            if boxes.numel() == 0:
                # Future work: deal with outputs with 0 objects properly.
                boxes = torch.zeros(1, 2, 4, dtype=boxes.dtype, device=boxes.device)
            padded.append(boxes)
        return tuple(padded)

    def get_binary_road_map(self, samples):
        """Return the per-pixel argmax class map, size [batch_size, 800, 800].

        `samples` is a tensor of size [batch_size, 6, 3, H, W].
        """
        with torch.no_grad():
            output = self.road_model(samples)

        return torch.argmax(output, dim=1)


In [7]:
!ls ./model/pirl_jigsaw/

rep_net2.pth  rep_net3.pth


In [8]:
# Build both models (road segmentation and box detection) from the PIRL-pretrained backbone.
# test defaults to False, so no fine-tuned checkpoints are loaded here.
# NOTE(review): this reads ./model/pirl_jigsaw/rep_net3.pth, whose keys are rewritten by the
# preprocess_state_dict cell further down -- with an unprocessed checkpoint that cell
# presumably has to run first; confirm the intended cell order.
model_loader = ModelLoader()

Downloading: "https://download.pytorch.org/models/resnet50-19c8e357.pth" to /root/.cache/torch/hub/checkpoints/resnet50-19c8e357.pth


HBox(children=(FloatProgress(value=0.0, max=102502400.0), HTML(value='')))


Load pretrained backbone successfully!


In [9]:
def num_of_trained_parameters(model):
    """Print each parameter's name and shape, then return the number of trainable scalars.

    Only parameters with `requires_grad=True` contribute to the returned count.
    """
    for name, tensor in model.named_parameters():
        print (name, tensor.size())

    trainable_sizes = [p.numel() for p in model.parameters() if p.requires_grad]
    return sum(trainable_sizes)
# Lists every parameter of the detection model and displays its trainable-parameter count.
num_of_trained_parameters(model_loader.box_model)

fuse_layer.backbone.conv1.weight torch.Size([64, 3, 7, 7])
fuse_layer.backbone.bn1.weight torch.Size([64])
fuse_layer.backbone.bn1.bias torch.Size([64])
fuse_layer.backbone.layer1.0.conv1.weight torch.Size([64, 64, 1, 1])
fuse_layer.backbone.layer1.0.bn1.weight torch.Size([64])
fuse_layer.backbone.layer1.0.bn1.bias torch.Size([64])
fuse_layer.backbone.layer1.0.conv2.weight torch.Size([64, 64, 3, 3])
fuse_layer.backbone.layer1.0.bn2.weight torch.Size([64])
fuse_layer.backbone.layer1.0.bn2.bias torch.Size([64])
fuse_layer.backbone.layer1.0.conv3.weight torch.Size([256, 64, 1, 1])
fuse_layer.backbone.layer1.0.bn3.weight torch.Size([256])
fuse_layer.backbone.layer1.0.bn3.bias torch.Size([256])
fuse_layer.backbone.layer1.0.downsample.0.weight torch.Size([256, 64, 1, 1])
fuse_layer.backbone.layer1.0.downsample.1.weight torch.Size([256])
fuse_layer.backbone.layer1.0.downsample.1.bias torch.Size([256])
fuse_layer.backbone.layer1.1.conv1.weight torch.Size([64, 256, 1, 1])
fuse_layer.backbone.la

112611082

In [10]:
#because pirl is trained with DataParallel, we need to modify the key of the state_dict
def preprocess_state_dict(file_path):
    """Strip the DataParallel 'module.' prefix from a saved state_dict, in place.

    The checkpoint at `file_path` is loaded, every key that starts with
    'module.' loses that prefix, and the result overwrites the same file.
    Keys without the prefix are left untouched, so running this twice is safe.

    Args:
        file_path: path of the checkpoint to rewrite.
    """
    from collections import OrderedDict

    prefix = 'module.'
    # map_location keeps this usable on a runtime without the GPU the
    # checkpoint was saved from.
    state_dict = torch.load(file_path, map_location='cpu')

    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        # Only a leading 'module.' is the wrapper prefix; a key that merely
        # contains "module" somewhere (e.g. 'head.submodule.w') must be kept.
        if k.startswith(prefix):
            k = k[len(prefix):]
        new_state_dict[k] = v
    torch.save(new_state_dict, file_path)

# Rewrites the PIRL checkpoint on disk so its keys no longer carry the DataParallel prefix.
preprocess_state_dict('./model/pirl_jigsaw/rep_net3.pth')

In [11]:
#split train / validation scenes
# 'bs' is the batch size used by the training DataLoaders below.
args = {'bs': 2}

# Scenes 128-133 are held out for validation; scenes 106-127 are used for training.
# The misspelled name `labeld_train_scence_index` is referenced by later code, so it is kept.
labeled_val_scene_index = np.arange(128, 134)
labeld_train_scence_index = np.arange(106,128)

from helper import collate_fn
def _resize_normalize_transform():
    """Resize to 198x198, convert to a tensor, then normalise with fixed per-channel mean/std."""
    T = torchvision.transforms
    return T.Compose([
        T.Resize((198, 198)),
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406],
                    std=[0.229, 0.224, 0.225]),
    ])


def get_transform_task1():
    """Image transform for task 1 (the bounding-box task)."""
    return _resize_normalize_transform()


def get_transform_task2():
    """Image transform for task 2 (the road-map task); currently the same pipeline as task 1."""
    return _resize_normalize_transform()
def _labeled_dataset(scene_index, transform):
    """LabeledDataset over the given scenes, without the extra-info fields."""
    return LabeledDataset(
        image_folder=image_folder,
        annotation_file=annotation_csv,
        scene_index=scene_index,
        transform=transform,
        extra_info=False,
    )


def _train_loader(dataset):
    """Shuffled loader with the configured batch size and the project's collate_fn."""
    return torch.utils.data.DataLoader(
        dataset,
        batch_size=args['bs'],
        shuffle=True,
        num_workers=2,
        collate_fn=collate_fn,
    )


def _val_loader(dataset):
    """Ordered, one-sample-at-a-time loader using the default collation."""
    return torch.utils.data.DataLoader(
        dataset,
        batch_size=1,
        shuffle=False,
        num_workers=2,
    )


# ---- Training split (task1 = bounding boxes, task2 = road map) ----
trainset_task1 = _labeled_dataset(labeld_train_scence_index, get_transform_task1())
trainset_task2 = _labeled_dataset(labeld_train_scence_index, get_transform_task2())

trainloader_task1 = _train_loader(trainset_task1)
trainloader_task2 = _train_loader(trainset_task2)

# ---- Validation split ----
# Despite the "trainset" in their names, these are built from labeled_val_scene_index.
# For bounding boxes task
labeled_trainset_task1 = _labeled_dataset(labeled_val_scene_index, get_transform_task1())
dataloader_task1 = _val_loader(labeled_trainset_task1)

# For road map task
labeled_trainset_task2 = _labeled_dataset(labeled_val_scene_index, get_transform_task2())
dataloader_task2 = _val_loader(labeled_trainset_task2)



In [12]:
!ls model

box_detect	     pirl_jigsaw  road_seg_BH	road_seg_SSL_new
model_v1_state_dict  road_seg	  road_seg_SSL


In [13]:
def train_road_model(road_model, optimizer, criterion,  trainloader, dataloader_task2, epoch, best_TS, log_interval = 50):
    """Train the road model for one epoch, validate it, and checkpoint on improvement.

    Args:
        road_model: model to train; inputs are moved to the device of its parameters.
        optimizer: optimizer over `road_model`'s parameters.
        criterion: loss taking (logits, integer road map).
        trainloader: yields (images, bbox, road_image) tuples of per-sample items.
        dataloader_task2: validation loader yielding (sample, target, road_image).
        epoch: current epoch index (used for logging and checkpoint names).
        best_TS: best validation threat score seen so far.
        log_interval: number of iterations between progress prints.

    Returns:
        (road_map_score, best_TS): this epoch's mean validation score and the
        possibly updated best score.
    """
    model_device = next(road_model.parameters()).device

    road_model.train()
    running_loss = 0.0
    batches_since_log = 0

    for itr, (images, bbox, road_image ) in enumerate(trainloader):
        images = torch.stack(images).to(model_device)
        road_image = torch.stack(road_image).to(model_device)

        optimizer.zero_grad()

        output = road_model(images)
        loss = criterion(output, road_image*1)  # *1 turns the road mask into integer class ids
        loss.backward()
        optimizer.step()
        # .item() keeps the running total a plain float, so no autograd
        # history is retained between iterations.
        running_loss += loss.item()
        batches_since_log += 1

        if itr > 0 and itr % log_interval == 0:
            print("""Train Epoch: {} [{}/{} ({:.0f}%)]\t 
                    Current Loss: {:.6f}, Average Loss: {:.6f}""".format(
                epoch, itr * len(images), len(trainloader.dataset),
                100. * itr / len(trainloader), loss.item(), running_loss / batches_since_log))
            running_loss = 0.0
            batches_since_log = 0

    # Validate once per epoch, using the model that was passed in rather than
    # a global loader object.
    road_model.eval()
    total = 0
    total_ts_road_map = 0
    with torch.no_grad():
        for sample, target, road_image in dataloader_task2:
            total += 1
            sample = sample.to(model_device)

            predicted_road_map = torch.argmax(road_model(sample), dim=1).cpu()
            ts_road_map = helper.compute_ts_road_map(predicted_road_map, road_image)
            total_ts_road_map += ts_road_map

    road_map_score = total_ts_road_map / total
    print(f'Road Map Score: {road_map_score:.4}')
    # Checkpoints are only written from the second epoch on, and only on improvement.
    if epoch >= 1 and road_map_score > best_TS:
        best_TS = road_map_score
        snapshot_name = 'fintune_epoch_%.0f_raod_map_score_%.5f_lr_%.10f' % (
                epoch, road_map_score, optimizer.param_groups[0]['lr']
            )
        os.makedirs('./model/road_seg/', exist_ok=True)
        torch.save(road_model.state_dict(), './model/road_seg/' + snapshot_name + '.pth')
        torch.save(optimizer.state_dict(), './model/road_seg/opt_' + snapshot_name + '.pth')

    return road_map_score, best_TS

# Hyper-parameters for fine-tuning the road model.
best_TS = 0.1          # a checkpoint is only written once the validation score beats this
lr         = 1e-5
momentum   = 0         # only referenced by the commented-out RMSprop alternative below
w_decay    = 0.01
log_interval = 100     # NOTE: not passed to train_road_model below, which uses its own default

criterion = nn.CrossEntropyLoss().to(device)
optimizer = optim.AdamW(model_loader.road_model.parameters(), lr= lr, betas=(0.9,0.999), eps=1e-08, weight_decay=w_decay, amsgrad=False)
#optimizer = optim.RMSprop(model_loader.road_model.parameters(), lr=lr, momentum=momentum, weight_decay=w_decay)
# 'max' mode because the scheduler is stepped on the validation score, where higher is better.
scheduler = ReduceLROnPlateau(optimizer,'max', patience= 2, min_lr=1e-10, verbose=True) #use it on TS


# One call per epoch: train, validate, maybe checkpoint; then let the scheduler see the score.
for epoch in range(30):    
    road_map_score, best_TS = train_road_model(model_loader.road_model.to(device), optimizer, criterion, trainloader_task2, dataloader_task2, epoch, best_TS)
    scheduler.step(road_map_score)  




                    Current Loss: 0.505769, Average Loss: 0.476861


KeyboardInterrupt: ignored

In [16]:
#-----------------------------------------------------------------------------------------------------------------
def train_box_model(box_model, optimizer, trainloader, dataloader_task1, epoch,best_ats, log_interval = 100):
    """Train the box model for one epoch, validate it, and checkpoint on improvement.

    Args:
        box_model: detection model; returns a loss dict in training mode and a
            dict with a 'boxes' tuple in eval mode.
        optimizer: optimizer over `box_model`'s parameters.
        trainloader: yields (images, bbox, road_image) tuples of per-sample items.
        dataloader_task1: validation loader yielding (sample, target, road_image),
            one sample per batch.
        epoch: current epoch index (used for logging and checkpoint names).
        best_ats: best validation average threat score seen so far.
        log_interval: number of iterations between progress prints.

    Returns:
        (avg_ats, best_ats): this epoch's mean validation score and the
        possibly updated best score.
    """
    model_device = next(box_model.parameters()).device

    box_model.train()
    running_loss = 0.0
    batches_since_log = 0

    for itr, (images, bbox, road_image ) in enumerate(trainloader):
        images = torch.stack(images).to(model_device)
        loss_dict = box_model(images,bbox)
        losses = sum(loss for loss in loss_dict.values())  # total of all detector losses
        optimizer.zero_grad()
        losses.backward()
        optimizer.step()

        # .item() keeps the running total a plain float, so no autograd
        # history is retained between iterations.
        running_loss += losses.item()
        batches_since_log += 1

        if itr > 0 and itr % log_interval == 0:
            print("""Train Epoch: {} [{}/{} ({:.0f}%)]\t 
                    Average Loss: {:.6f}""".format(
                epoch, itr * len(images), len(trainloader.dataset),
                100. * itr / len(trainloader),  running_loss / batches_since_log))
            running_loss = 0.0
            batches_since_log = 0

    # Validate once per epoch.
    total = 0
    total_ats_bounding_boxes = 0
    box_model.eval()

    with torch.no_grad():
        for sample, target, road_image in dataloader_task1:
            total += 1
            sample = sample.to(model_device)
            output = box_model(sample)
            pred_boxes = output['boxes'][0].cpu()
            # Future work: deal with outputs with 0 objects.
            # An empty prediction has shape (0, 2, 4), so it must be tested
            # with numel() on the whole tensor (indexing [0] would raise) and
            # replaced by a single all-zero box tensor.
            if pred_boxes.numel() == 0:
                pred_boxes = torch.zeros(1, 2, 4)

            ats_bounding_boxes = helper.compute_ats_bounding_boxes(pred_boxes, target['bounding_box'][0])
            total_ats_bounding_boxes += ats_bounding_boxes

    avg_ats = total_ats_bounding_boxes/total
    print('\nValidataion set: Average ATS: {:.4f}\n'.format(avg_ats))
    # Checkpoints are only written from the second epoch on, and only on improvement.
    if epoch >= 1 and avg_ats > best_ats:
        best_ats = avg_ats
        snapshot_name = 'fintune_epoch_%.0f_avg_ats_%.5f_lr_%.10f' % (
                epoch, avg_ats, optimizer.param_groups[0]['lr']
            )
        os.makedirs('./model/box_detect/', exist_ok=True)
        torch.save(box_model.state_dict(), './model/box_detect/' + snapshot_name + '.pth')
        torch.save(optimizer.state_dict(), './model/box_detect/opt_' + snapshot_name + '.pth')

    return avg_ats, best_ats


# Hyper-parameters for resuming the box-detection fine-tuning run.
lr         = 1e-5
momentum   = 0.9
w_decay    = 0.0005
best_ats = 0.01862     # score encoded in the name of the optimizer checkpoint loaded below

optimizer = torch.optim.SGD(model_loader.box_model.parameters(), lr=lr,
                           momentum=momentum, weight_decay=w_decay)
# Resume the optimizer from the epoch-2 checkpoint.
# NOTE(review): load_state_dict presumably also restores the saved param-group settings
# (the file name says lr 1e-4), which would override the lr set above -- confirm.
optimizer.load_state_dict(torch.load('./model/box_detect/opt_fintune_epoch_2_avg_ats_0.01862_lr_0.0001000000.pth'))
# Move any tensors held in the restored optimizer state onto the GPU.
for state in optimizer.state.values():
    for k, v in state.items():
        if isinstance(v, torch.Tensor):
            state[k] = v.cuda()

# 'max' mode because the scheduler is stepped on the validation score, where higher is better.
scheduler = ReduceLROnPlateau(optimizer,'max', patience=2, min_lr=1e-10, verbose=True) 

# Epoch numbering continues from 3, after the epoch-2 checkpoint.
# NOTE(review): only the optimizer state is restored in this cell; model_loader was built with
# test=False, so the matching model weights do not appear to be loaded anywhere above -- confirm.
for epoch in range(3, 208):    
    avg_ats, best_ats = train_box_model(model_loader.box_model.to(device), optimizer, trainloader_task1, dataloader_task1, epoch, best_ats)
    scheduler.step(avg_ats)  


RuntimeError: ignored