In [13]:
import torch
import torchvision
import math
import pandas as pd
import numpy as np
import gc
from tqdm import tqdm_notebook
from torch.utils.tensorboard import SummaryWriter


from GWD_data import WheatDataset
from metric import calculate_image_precision
from utils import collate_fn, plot_boxes, get_model_name, format_prediction_string, remove_blanks, is_contain_blanks
from config import config


In [14]:
conf = config()
torch.random.manual_seed(5)
writer = SummaryWriter("runs/May11_23-57-42_DESKTOP-ELPLUSQ") #


WD_Train = WheatDataset(conf)
WD_Valid = WheatDataset(conf, train=False)
WD_Train_Loader = torch.utils.data.DataLoader(WD_Train, batch_size=conf.BATH_SIZE, shuffle=True, collate_fn=collate_fn)
WD_Valid_Loader = torch.utils.data.DataLoader(WD_Valid, batch_size=conf.BATH_SIZE, shuffle=True, collate_fn=collate_fn)

In [15]:
import torch.nn.functional as F
from torch import nn

# model_path = 'model/GWD_EPOCH_29_SCORE_0.0000_LOSS_191.8370.pt'
# GWD_Model = torch.load(model_path) if model_path else torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=False, progress=True, num_classes=2, pretrained_backbone=True)

EPOCH = 0 #int(model_path.split('/')[-1].split('_')[2])+1 if model_path else 0

GWD_detr_Model = torch.hub.load('facebookresearch/detr', 'detr_resnet50', pretrained=True)
# model_dict_path = './model_dicts/detr-r50-e632da11.pth'
# GWD_detr_Model.load_state_dict(torch.load(model_dict_path))

print(GWD_detr_Model.bbox_embed)
print(GWD_detr_Model.class_embed)

num_classes=1
hidden_dim=256

GWD_detr_Model.class_embed = nn.Linear(hidden_dim, num_classes + 1)

print(GWD_detr_Model.bbox_embed)
print(GWD_detr_Model.class_embed)

GWD_detr_Model.to(conf.DEVICE)
params = [p for p in GWD_detr_Model.parameters() if p.requires_grad]
optimizer = torch.optim.Adam(params, lr=0.0001, weight_decay=0.0001)      #These very the default values in source code

Using cache found in C:\Users\mihir/.cache\torch\hub\facebookresearch_detr_master


MLP(
  (layers): ModuleList(
    (0): Linear(in_features=256, out_features=256, bias=True)
    (1): Linear(in_features=256, out_features=256, bias=True)
    (2): Linear(in_features=256, out_features=4, bias=True)
  )
)
Linear(in_features=256, out_features=92, bias=True)
MLP(
  (layers): ModuleList(
    (0): Linear(in_features=256, out_features=256, bias=True)
    (1): Linear(in_features=256, out_features=256, bias=True)
    (2): Linear(in_features=256, out_features=4, bias=True)
  )
)
Linear(in_features=256, out_features=2, bias=True)


In [16]:
from detr.models.detr import SetCriterion
from detr.models.matcher import HungarianMatcher

#These are the defualt values in the source code
bbox_loss_coef=5.0 
giou_loss_coef=2.0
eos_coef=0.1
set_cost_class=1.0
set_cost_bbox=5.0
set_cost_giou=2.0


matcher = HungarianMatcher(cost_class=set_cost_class, cost_bbox=set_cost_bbox, cost_giou=set_cost_giou)
weight_dict = {'loss_ce': 1, 'loss_bbox': bbox_loss_coef}
weight_dict['loss_giou'] = giou_loss_coef
losses = ['labels', 'boxes', 'cardinality']

criterion = SetCriterion(num_classes, matcher=matcher, weight_dict=weight_dict,eos_coef=eos_coef, losses=losses)
criterion.to(conf.DEVICE)

SetCriterion(
  (matcher): HungarianMatcher()
)

In [20]:
_iter=len(WD_Train_Loader)*(EPOCH)
clip_max_norm=0.1  #Default value in their code

for i in tqdm_notebook(range(EPOCH, 50, 1)):
    
    _epoch_loss=0
    _ = GWD_detr_Model.train()
    _ = criterion.train()
    
    for images, targets in tqdm_notebook(WD_Train_Loader):

        images = [torch.tensor(image, dtype = torch.float32).to(conf.DEVICE) for image in images]
        targets = [{k: torch.tensor(v).to(conf.DEVICE) for k, v in target.items()} for target in targets]
        
        #Main fwd pass and loss calc  
        outputs = GWD_detr_Model(images)
        loss_dict = criterion(outputs, targets)
        weight_dict = criterion.weight_dict
        summed_loss_value = sum(loss_dict[k] * weight_dict[k] for k in loss_dict.keys() if k in weight_dict)


        if math.isfinite(summed_loss_value):
            optimizer.zero_grad()
            summed_loss_value.backward()
            if clip_max_norm > 0:
                torch.nn.utils.clip_grad_norm_(GWD_detr_Model.parameters(), clip_max_norm)
            optimizer.step()
            _epoch_loss+=summed_loss_value
            
            if(_iter%1 == 0):
                print("".join([k+":"+str(v.data.cpu().numpy())+", " for k,v in loss_dict.items()]))

        else:
            print('Loss is undefined:',summed_loss_value,'   skipping BackProp for step no:',_iter)
            print(loss_dict)


        writer.add_scalar('Running Loss/Summed', summed_loss_value, _iter)
        writer.add_scalar('Running Loss/CE', loss_dict['loss_ce'].item(), _iter)
        writer.add_scalar('Running Loss/Classifier', loss_dict['class_error'].item(), _iter)
        writer.add_scalar('Running Loss/Box_Regress', loss_dict['loss_bbox'].item(), _iter)
        writer.add_scalar('Running Loss/loss_giou', loss_dict['loss_giou'].item(), _iter)
        writer.add_scalar('Running Loss/cardinality_error', loss_dict['cardinality_error'].item(), _iter)
         
        _iter+=1
        
    print('Saving model at epoch '+str(i)+', step '+str(_iter))
    torch.save(GWD_detr_Model, 'model-detr/'+get_model_name(i, 0, _epoch_loss))
    torch.cuda.empty_cache()
    gc.collect()
    break

Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for i in tqdm_notebook(range(EPOCH, 50, 1)):


HBox(children=(FloatProgress(value=0.0, max=50.0), HTML(value='')))

Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for images, targets in tqdm_notebook(WD_Train_Loader):


HBox(children=(FloatProgress(value=0.0, max=1349.0), HTML(value='')))

loss_ce:0.22639349, class_error:0.0, loss_bbox:2070.5508, loss_giou:1.4116917, cardinality_error:64.5, 
loss_ce:0.14303052, class_error:0.0, loss_bbox:1959.4651, loss_giou:1.3552854, cardinality_error:60.0, 
loss_ce:0.10310736, class_error:0.0, loss_bbox:2084.1558, loss_giou:1.4054142, cardinality_error:54.5, 
loss_ce:0.07172582, class_error:-7.6293945e-06, loss_bbox:1753.5231, loss_giou:1.2988861, cardinality_error:22.5, 




KeyboardInterrupt: 