In [1]:
import torch
import torchvision
import math
import pandas as pd
import numpy as np
import gc
from tqdm import tqdm_notebook
from torch.utils.tensorboard import SummaryWriter


from GWD_data import WheatDataset
from metric import calculate_image_precision
from utils import collate_fn, plot_boxes, get_model_name, format_prediction_string, remove_blanks, is_contain_blanks
from config import config


In [2]:
conf = config()
torch.random.manual_seed(5)
writer = SummaryWriter("runs/May11_23-57-42_DESKTOP-ELPLUSQ") #


WD_Train = WheatDataset(conf)
WD_Valid = WheatDataset(conf, train=False)
WD_Train_Loader = torch.utils.data.DataLoader(WD_Train, batch_size=conf.BATH_SIZE, shuffle=True, collate_fn=collate_fn)
WD_Valid_Loader = torch.utils.data.DataLoader(WD_Valid, batch_size=conf.BATH_SIZE, shuffle=True, collate_fn=collate_fn)

In [17]:
import torch.nn.functional as F
from torch import nn

# model_path = 'model/GWD_EPOCH_29_SCORE_0.0000_LOSS_191.8370.pt'
# GWD_Model = torch.load(model_path) if model_path else torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=False, progress=True, num_classes=2, pretrained_backbone=True)

EPOCH = 0 #int(model_path.split('/')[-1].split('_')[2])+1 if model_path else 0

GWD_detr_Model = torch.hub.load('facebookresearch/detr', 'detr_resnet50', pretrained=True)
# model_dict_path = './model_dicts/detr-r50-e632da11.pth'
# GWD_detr_Model.load_state_dict(torch.load(model_dict_path))

print(GWD_detr_Model.bbox_embed)
print(GWD_detr_Model.class_embed)

num_classes=1
hidden_dim=256

GWD_detr_Model.class_embed = nn.Linear(hidden_dim, num_classes + 1)

print(GWD_detr_Model.bbox_embed)
print(GWD_detr_Model.class_embed)

GWD_detr_Model.to(conf.DEVICE)
params = [p for p in GWD_detr_Model.parameters() if p.requires_grad]
optimizer = torch.optim.Adam(params, lr=0.0005, weight_decay=0.0005)

Using cache found in C:\Users\mihir/.cache\torch\hub\facebookresearch_detr_master


MLP(
  (layers): ModuleList(
    (0): Linear(in_features=256, out_features=256, bias=True)
    (1): Linear(in_features=256, out_features=256, bias=True)
    (2): Linear(in_features=256, out_features=4, bias=True)
  )
)
Linear(in_features=256, out_features=92, bias=True)
MLP(
  (layers): ModuleList(
    (0): Linear(in_features=256, out_features=256, bias=True)
    (1): Linear(in_features=256, out_features=256, bias=True)
    (2): Linear(in_features=256, out_features=4, bias=True)
  )
)
Linear(in_features=256, out_features=2, bias=True)


In [26]:
from detr.models.detr import SetCriterion
from detr.models.matcher import HungarianMatcher

#These are the defualt values in the source code
bbox_loss_coef=5.0 
giou_loss_coef=2.0
eos_coef=0.1
set_cost_class=1.0
set_cost_bbox=5.0
set_cost_giou=2.0


matcher = HungarianMatcher(cost_class=set_cost_class, cost_bbox=set_cost_bbox, cost_giou=set_cost_giou)
weight_dict = {'loss_ce': 1, 'loss_bbox': bbox_loss_coef}
weight_dict['loss_giou'] = giou_loss_coef
losses = ['labels', 'boxes', 'cardinality']

criterion = SetCriterion(num_classes, matcher=matcher, weight_dict=weight_dict,eos_coef=eos_coef, losses=losses)
criterion.to(conf.DEVICE)

SetCriterion(
  (matcher): HungarianMatcher()
)

In [28]:
_iter=len(WD_Train_Loader)*(EPOCH)

for i in tqdm_notebook(range(EPOCH, 50, 1)):
    
    _epoch_loss=0
    _ = GWD_detr_Model.train()
    _ = criterion.train()
    for images, targets in tqdm_notebook(WD_Train_Loader):
        
        #if not is_contain_blanks(targets):
        try:
            images = [torch.tensor(image, dtype = torch.float32).to(conf.DEVICE) for image in images]
            targets = [{k: torch.tensor(v).to(conf.DEVICE) for k, v in target.items()} for target in targets]

            #loss_dictionary = GWD_Model(images,targets)
            #summed_loss = sum(loss for loss in loss_dictionary.values())
            #summed_loss_value = summed_loss.item()
            
            outputs = GWD_detr_Model(images)
            loss_dict = criterion(outputs, targets)
            weight_dict = criterion.weight_dict
            summed_loss_value = sum(loss_dict[k] * weight_dict[k] for k in loss_dict.keys() if k in weight_dict)


            if math.isfinite(summed_loss_value):
                optimizer.zero_grad()
                summed_loss.backward()
                optimizer.step()
                _epoch_loss+=summed_loss_value

            else:
                print('Loss is undefined:',summed_loss_value,'   skipping BackProp for step no:',_iter)
                print(loss_dictionary)


            writer.add_scalar('Running Loss/Summed', summed_loss_value, _iter)
            writer.add_scalar('Running Loss/Classifier', loss_dictionary['loss_classifier'].item(), _iter)
            writer.add_scalar('Running Loss/Box_Regress', loss_dictionary['loss_box_reg'].item(), _iter)
            writer.add_scalar('Running Loss/Objectness', loss_dictionary['loss_objectness'].item(), _iter)
            writer.add_scalar('Running Loss/RPN_Box_regress', loss_dictionary['loss_rpn_box_reg'].item(), _iter)
            
        except:
            print('Some weird shape error due to blank images, but handeled')
        
        _iter+=1
        
    torch.save(GWD_detr_Model, 'model/'+get_model_name(i, 0, _epoch_loss))
    torch.cuda.empty_cache()
    gc.collect()

Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for i in tqdm_notebook(range(EPOCH, 50, 1)):


HBox(children=(FloatProgress(value=0.0, max=50.0), HTML(value='')))

Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for images, targets in tqdm_notebook(WD_Train_Loader):


HBox(children=(FloatProgress(value=0.0, max=338.0), HTML(value='')))

Some weird shape error due to blank images, but handeled
Some weird shape error due to blank images, but handeled
Some weird shape error due to blank images, but handeled
Some weird shape error due to blank images, but handeled
Some weird shape error due to blank images, but handeled
Some weird shape error due to blank images, but handeled
Some weird shape error due to blank images, but handeled



Exception ignored in: <generator object tqdm.__iter__ at 0x0000018F30198200>
Traceback (most recent call last):
  File "D:\Anaconda3\envs\pytorch_env\lib\site-packages\tqdm\std.py", line 1182, in __iter__
    self.close()
  File "D:\Anaconda3\envs\pytorch_env\lib\site-packages\tqdm\notebook.py", line 241, in close
    super(tqdm_notebook, self).close(*args, **kwargs)
  File "D:\Anaconda3\envs\pytorch_env\lib\site-packages\tqdm\std.py", line 1293, in close
    self.display(pos=0)
KeyboardInterrupt: 


KeyboardInterrupt: 