In [2]:
import torch
import torchvision
import math
import pandas as pd
import numpy as np
import gc

from tqdm import tqdm_notebook
from torch.utils.tensorboard import SummaryWriter


from GWD_data import WheatDataset
from metric import calculate_image_precision
from utils import collate_fn, plot_boxes, get_model_name, format_prediction_string, remove_blanks, is_contain_blanks
from config import config


In [3]:
conf = config()
torch.random.manual_seed(5)
writer = SummaryWriter("runs/AugTraining_1") #


WD_Train = WheatDataset(conf)
WD_Valid = WheatDataset(conf, train=False)
WD_Train_Loader = torch.utils.data.DataLoader(WD_Train, batch_size=conf.BATH_SIZE, shuffle=True, collate_fn=collate_fn)
WD_Valid_Loader = torch.utils.data.DataLoader(WD_Valid, batch_size=conf.BATH_SIZE, shuffle=True, collate_fn=collate_fn)

In [4]:
def get_model(num_classes=2):
    print('New Model ')
    model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True, progress=True, pretrained_backbone=True)
    model.roi_heads.box_predictor.cls_score = torch.nn.Linear(model.roi_heads.box_predictor.cls_score.in_features, num_classes, bias=True)
    model.roi_heads.box_predictor.bbox_pred = torch.nn.Linear(model.roi_heads.box_predictor.bbox_pred.in_features, num_classes*4, bias=True)
    return model

In [5]:
model_path = 'model/model_aug/GWD_EPOCH_89_SCORE_0.0000_LOSS_184.3004.pt'
GWD_Model = torch.load(model_path) if model_path else get_model()
GWD_Model.to(conf.DEVICE)
EPOCH = int(model_path.split('/')[-1].split('_')[2])+1 if model_path else 0

params = [p for p in GWD_Model.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(params, lr=0.0005,momentum=0.9, weight_decay=0.0005)

In [6]:
_iter=len(WD_Train_Loader)*(EPOCH)

for i in tqdm_notebook(range(EPOCH, 100, 1)):
    
    _epoch_loss=0
    _ = GWD_Model.train()
    for images, targets in tqdm_notebook(WD_Train_Loader):
        
        #if not is_contain_blanks(targets):
        try:
            images = [torch.tensor(image, dtype = torch.float32).to(conf.DEVICE) for image in images]
            targets = [{k: torch.tensor(v).to(conf.DEVICE) for k, v in target.items()} for target in targets]

            loss_dictionary = GWD_Model(images,targets)
            summed_loss = sum(loss for loss in loss_dictionary.values())
            summed_loss_value = summed_loss.item()


            if math.isfinite(summed_loss_value):
                optimizer.zero_grad()
                summed_loss.backward()
                optimizer.step()
                _epoch_loss+=summed_loss_value

            else:
                print('Loss is undefined:',summed_loss_value,'   skipping BackProp for step no:',_iter)
                print(loss_dictionary)


            writer.add_scalar('Running Loss/Summed', summed_loss_value, _iter)
            writer.add_scalar('Running Loss/Classifier', loss_dictionary['loss_classifier'].item(), _iter)
            writer.add_scalar('Running Loss/Box_Regress', loss_dictionary['loss_box_reg'].item(), _iter)
            writer.add_scalar('Running Loss/Objectness', loss_dictionary['loss_objectness'].item(), _iter)
            writer.add_scalar('Running Loss/RPN_Box_regress', loss_dictionary['loss_rpn_box_reg'].item(), _iter)
            
        except:
            print('Some weird shape error due to blank images, but handeled')
        
        _iter+=1
        
    torch.save(GWD_Model, 'model/model_aug/'+get_model_name(i, 0, _epoch_loss))
    torch.cuda.empty_cache()
    gc.collect()

HBox(children=(IntProgress(value=0, max=10), HTML(value='')))

HBox(children=(IntProgress(value=0, max=338), HTML(value='')))

	nonzero(Tensor input, *, Tensor out)
Consider using one of the following signatures instead:
	nonzero(Tensor input, *, bool as_tuple)


Some weird shape error due to blank images, but handeled
Some weird shape error due to blank images, but handeled
Some weird shape error due to blank images, but handeled



HBox(children=(IntProgress(value=0, max=338), HTML(value='')))

Some weird shape error due to blank images, but handeled
Some weird shape error due to blank images, but handeled
Some weird shape error due to blank images, but handeled
Some weird shape error due to blank images, but handeled
Some weird shape error due to blank images, but handeled



HBox(children=(IntProgress(value=0, max=338), HTML(value='')))

Some weird shape error due to blank images, but handeled
Some weird shape error due to blank images, but handeled
Some weird shape error due to blank images, but handeled
Some weird shape error due to blank images, but handeled



HBox(children=(IntProgress(value=0, max=338), HTML(value='')))

Some weird shape error due to blank images, but handeled
Some weird shape error due to blank images, but handeled
Some weird shape error due to blank images, but handeled
Some weird shape error due to blank images, but handeled



HBox(children=(IntProgress(value=0, max=338), HTML(value='')))

Some weird shape error due to blank images, but handeled
Some weird shape error due to blank images, but handeled
Some weird shape error due to blank images, but handeled



HBox(children=(IntProgress(value=0, max=338), HTML(value='')))

Some weird shape error due to blank images, but handeled
Some weird shape error due to blank images, but handeled
Some weird shape error due to blank images, but handeled



HBox(children=(IntProgress(value=0, max=338), HTML(value='')))

Some weird shape error due to blank images, but handeled
Some weird shape error due to blank images, but handeled
Some weird shape error due to blank images, but handeled



HBox(children=(IntProgress(value=0, max=338), HTML(value='')))

Some weird shape error due to blank images, but handeled
Some weird shape error due to blank images, but handeled
Some weird shape error due to blank images, but handeled
Some weird shape error due to blank images, but handeled



HBox(children=(IntProgress(value=0, max=338), HTML(value='')))

Some weird shape error due to blank images, but handeled
Some weird shape error due to blank images, but handeled
Some weird shape error due to blank images, but handeled
Some weird shape error due to blank images, but handeled
Some weird shape error due to blank images, but handeled
Some weird shape error due to blank images, but handeled



HBox(children=(IntProgress(value=0, max=338), HTML(value='')))

Some weird shape error due to blank images, but handeled
Some weird shape error due to blank images, but handeled


