In [1]:
from torch.utils.data import DataLoader
import gc
import torch
from tqdm.notebook import tqdm
import math

from detr.config import config
from augmentations.aug import get_augmentor
from detr.utils.data import WheatDataset, collate_fn
#from fasterrcnn.utils.training_handler import get_training_handler

In [5]:
# Project configuration object plus the augmentation pipeline handed to the dataset.
conf = config()
augmentor = get_augmentor()

#conf.DATA_PATH = os.path.join('/kaggle','input', 'global-wheat-detection')
#conf.SPLIT = 0.8
conf.BATCH_SIZE = 2  # batch size later read by the DataLoader cell

# 'cuda:1' when CUDA is available, CPU otherwise.
# NOTE(review): the later cells move the model and tensors to conf.DEVICE, not to
# this DEVICE -- confirm which one is intended.
DEVICE = torch.device('cuda:1' if torch.cuda.is_available() else 'cpu')

# Dataset built with augmentation and normalisation enabled and a fixed seed.
# is_train=True presumably selects the training split -- TODO confirm in WheatDataset.
WD_Train = WheatDataset(
    conf,
    is_train=True,
    augmentation=True,
    normalize=True,
    augmentor=augmentor,
    random_seed=0,
)

In [6]:
# Shuffled batches of conf.BATCH_SIZE samples, assembled by the project's collate_fn.
WD_Train_Loader = DataLoader(
    dataset=WD_Train,
    batch_size=conf.BATCH_SIZE,
    shuffle=True,
    collate_fn=collate_fn,
)


In [7]:
import torch.nn.functional as F
from torch import nn

# Epoch to start from; the training cell derives its step counter and epoch range from it.
EPOCH = 0

# DETR-ResNet50 architecture from torch.hub, created without pretrained weights and
# with num_classes=1 (the class head printed below therefore has 2 outputs).
GWD_detr_Model = torch.hub.load('facebookresearch/detr', 'detr_resnet50', pretrained=False, num_classes=1)


# Download (or reuse the cached) published DETR-R50 checkpoint.
# map_location='cpu' deserialises every tensor onto the CPU regardless of the device
# it was saved from: this works on CPU-only hosts and avoids parking a second copy of
# the weights on a GPU. The model is moved to conf.DEVICE afterwards anyway.
checkpoint = torch.hub.load_state_dict_from_url(
            url='https://dl.fbaipublicfiles.com/detr/detr-r50-e632da11.pth',
            map_location='cpu',
            check_hash=True)

# Drop the checkpoint's classification head so the freshly initialised head of this
# model is kept; strict=False below tolerates the two keys that are now missing.
del checkpoint["model"]["class_embed.weight"]
del checkpoint["model"]["class_embed.bias"]
load_result = GWD_detr_Model.load_state_dict(checkpoint["model"], strict=False)

# strict=False would also hide any *other* key mismatch, so report anything beyond
# the two keys removed on purpose. Nothing is printed when the load is as expected.
expected_missing = {"class_embed.weight", "class_embed.bias"}
unexplained_missing = set(load_result.missing_keys) - expected_missing
if unexplained_missing or load_result.unexpected_keys:
    print('Checkpoint/model key mismatch. Missing:', sorted(unexplained_missing),
          ' Unexpected:', list(load_result.unexpected_keys))

print(GWD_detr_Model.bbox_embed)
print(GWD_detr_Model.class_embed)

GWD_detr_Model.to(conf.DEVICE)
#print(GWD_detr_Model)

# Plain SGD over every parameter that requires gradients.
params = [p for p in GWD_detr_Model.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(params, lr=0.00001, weight_decay=0.0001)

Using cache found in C:\Users\jay/.cache\torch\hub\facebookresearch_detr_master


MLP(
  (layers): ModuleList(
    (0): Linear(in_features=256, out_features=256, bias=True)
    (1): Linear(in_features=256, out_features=256, bias=True)
    (2): Linear(in_features=256, out_features=4, bias=True)
  )
)
Linear(in_features=256, out_features=2, bias=True)


In [8]:
from detr.detr.models.detr import SetCriterion
from detr.detr.models.matcher import HungarianMatcher

num_classes = 1

# The coefficients below are the default values in the DETR source code.
set_cost_class, set_cost_bbox, set_cost_giou = 1.0, 5.0, 2.0  # matcher cost terms
bbox_loss_coef, giou_loss_coef = 5.0, 2.0                     # loss weights
eos_coef = 0.1                                                # forwarded to SetCriterion

# Matcher handed to the criterion below.
matcher = HungarianMatcher(
    cost_class=set_cost_class,
    cost_bbox=set_cost_bbox,
    cost_giou=set_cost_giou,
)

# Per-loss weights; the training loop only sums the loss_dict entries named here.
weight_dict = {
    'loss_ce': 1,
    'loss_bbox': bbox_loss_coef,
    'loss_giou': giou_loss_coef,
}
losses = ['labels', 'boxes', 'cardinality']

criterion = SetCriterion(
    num_classes,
    matcher=matcher,
    weight_dict=weight_dict,
    eos_coef=eos_coef,
    losses=losses,
)
# Left as the cell's last expression so the notebook still displays the module repr.
criterion.to(conf.DEVICE)

SetCriterion(
  (matcher): HungarianMatcher()
)

In [9]:
# Training loop: for every epoch in [EPOCH, NUM_EPOCHS) run one pass over
# WD_Train_Loader, optimise the weighted criterion losses, print the loss terms
# every LOG_EVERY global steps and save a state-dict checkpoint at the end of the epoch.

# Global step counter; starts at the number of steps implied by the starting EPOCH.
_iter=len(WD_Train_Loader)*(EPOCH)
clip_max_norm=0.1  #Default value in their code
NUM_EPOCHS = 50    # exclusive upper bound of the epoch range
LOG_EVERY = 50     # print the loss terms every LOG_EVERY global steps

_ = GWD_detr_Model.train()
_ = criterion.train()

# Loop-invariant: the per-loss weights never change during training.
weight_dict = criterion.weight_dict

for i in tqdm(range(EPOCH, NUM_EPOCHS, 1)):
    # Running sums of each term over the batches that were actually optimised.
    ep_loss_ce=0
    ep_class_error=0
    ep_loss_bbox=0
    ep_loss_giou=0
    ep_cardinality_error=0
    ep_steps=0  # number of batches that contributed to the sums above

    for images, targets in tqdm(WD_Train_Loader):

        # as_tensor accepts arrays and tensors alike and avoids the extra copy (and
        # UserWarning) that torch.tensor() produces when handed an existing tensor.
        images = [torch.as_tensor(image, dtype=torch.float32, device=conf.DEVICE) for image in images]
        targets = [{k: torch.as_tensor(v, device=conf.DEVICE) for k, v in target.items()} for target in targets]

        #Main fwd pass and loss calc
        outputs = GWD_detr_Model(images)

        loss_dict = criterion(outputs, targets)
        # Only the terms named in weight_dict are optimised; the remaining entries of
        # loss_dict are logged but do not enter the sum.
        summed_loss_value = sum(loss_dict[k] * weight_dict[k] for k in loss_dict.keys() if k in weight_dict)

        if math.isfinite(summed_loss_value):
            optimizer.zero_grad()
            summed_loss_value.backward()
            if clip_max_norm > 0:
                torch.nn.utils.clip_grad_norm_(GWD_detr_Model.parameters(), clip_max_norm)
            optimizer.step()

            ep_loss_ce+= loss_dict['loss_ce'].item()
            ep_class_error+= loss_dict['class_error'].item()
            ep_loss_bbox+= loss_dict['loss_bbox'].item()
            ep_loss_giou+= loss_dict['loss_giou'].item()
            ep_cardinality_error+= loss_dict['cardinality_error'].item()
            ep_steps+=1

            if(_iter%LOG_EVERY == 0):
                print("".join([k+":"+str(v.detach().cpu().numpy())+", " for k,v in loss_dict.items()]))

        else:
            # Non-finite loss: skip the optimiser step for this batch but keep going.
            print('Loss is undefined:',summed_loss_value,'   skipping BackProp for step no:',_iter)
            print(loss_dict)

        _iter+=1

    # Average over the batches that actually contributed. Skipped (non-finite) batches
    # added nothing to the sums, so dividing by len(WD_Train_Loader) would bias the
    # averages low. max(..., 1) guards the case where every batch was skipped.
    contributing_steps = max(ep_steps, 1)
    ep_loss_ce = float(ep_loss_ce/contributing_steps)
    ep_class_error = float(ep_class_error/contributing_steps)
    ep_loss_bbox = float(ep_loss_bbox/contributing_steps)
    ep_loss_giou = float(ep_loss_giou/contributing_steps)
    ep_cardinality_error = float(ep_cardinality_error/contributing_steps)
    # NOTE(review): this is the unweighted sum of all five averages, including the
    # class_error and cardinality_error terms that are not part of the optimised
    # loss; it is only used to label the checkpoint file.
    _epoch_loss = ep_loss_ce+ep_class_error+ep_loss_bbox+ep_loss_giou+ep_cardinality_error

    print('Saving model at epoch '+str(i)+', step '+str(_iter))
    print("5 Avg Losses: {0}, {1}, {2}, {3}, {4}".format(ep_loss_ce, ep_class_error, ep_loss_bbox, ep_loss_giou, ep_cardinality_error))
    validation_score=0  # placeholder: no validation is run in this cell
    torch.save(GWD_detr_Model.state_dict(), "./GWD_DETR_SD_Epoch_{0}_Score_{1:.4f}_EpLoss_{2:.4f}.pt".format(i, validation_score, _epoch_loss))
    # Release cached GPU blocks and Python garbage between epochs.
    torch.cuda.empty_cache()
    gc.collect()


HBox(children=(IntProgress(value=0, max=50), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1349), HTML(value='')))

loss_ce:0.3324999, class_error:1.0416718, loss_bbox:0.36965245, loss_giou:1.2597151, cardinality_error:47.0, 


KeyboardInterrupt: 