In [None]:
# Extra packages this notebook imports (`yaml`, `easydict`).
# Install them into the kernel's environment once if they are missing:
#   %pip install pyyaml easydict

In [None]:
import os, sys
import yaml
from pathlib import Path

from easydict import EasyDict as edict

import numpy as np
import torch
from torch.cuda import amp
from torch.optim import AdamW
from torch.optim.lr_scheduler import StepLR
from torch.utils.data.dataloader import DataLoader

from models import DETR, SetCriterion
from utils.dataset import collateFunction, COCODataset
from utils.misc import MetricsLogger, saveArguments, logMetrics, cast2Float

from tqdm.notebook import tqdm, trange

In [None]:
# All paths are resolved from the directory the kernel was started in:
# config.yaml is expected next to this notebook, data under its parent directory.
CURRENT_PATH = os.getcwd()
CONFIG = os.path.join(CURRENT_PATH, 'config.yaml')
BASE_PATH = Path(CURRENT_PATH).parent

In [None]:
# Display the repository root (parent of the working directory) to sanity-check
# the base that the dataset paths below are built from.
BASE_PATH


In [None]:
def parse_config(path=None):
    """Load the YAML training configuration.

    Parameters
    ----------
    path : str or os.PathLike, optional
        YAML file to read. Defaults to the notebook-level ``CONFIG`` path
        (``config.yaml`` in the working directory), so existing calls
        ``parse_config()`` behave as before.

    Returns
    -------
    object
        Whatever ``yaml.safe_load`` produces for the document.

    Raises
    ------
    ValueError
        If the file is not valid YAML. The underlying ``yaml.YAMLError`` is
        chained as ``__cause__`` so the line/column details are not lost.
    """
    configPath = CONFIG if path is None else path
    # YAML documents are UTF-8 by convention; don't depend on the platform default.
    with open(configPath, 'r', encoding='utf-8') as stream:
        try:
            return yaml.safe_load(stream)
        except yaml.YAMLError as exc:
            raise ValueError(f"Failed to parse config file: {configPath}") from exc

In [None]:
# Parse config.yaml and wrap the result in an EasyDict so settings are read as
# attributes (e.g. `args.batchSize`) by the cells below.
args = edict(parse_config())

In [None]:
# Use the GPU whenever PyTorch can see one. This unconditionally overwrites any
# `device` value that may have come from config.yaml.
if torch.cuda.is_available():
    args.device = 'cuda'
else:
    args.device = 'cpu'

In [None]:
# Persist the resolved configuration under the task name, then seed the RNGs and
# pick the compute device.
saveArguments(args, args.taskName)

# The seed can be set with an optional `seed` key in config.yaml; 1337 keeps the
# previous hard-coded behaviour when the key is absent.
SEED = args.get('seed', 1337)
torch.manual_seed(SEED)
# numpy is imported by this notebook as well; seed it so any numpy randomness
# (e.g. in data preprocessing) is repeatable too.
np.random.seed(SEED)

device = torch.device(args.device)

In [None]:
# COCO-mini training images and their instance annotations, resolved from the
# repository root. Components are joined separately so the platform separator is used.
train_dir = os.path.join(BASE_PATH, 'data', 'coco_mini', 'trainset')
# Note: despite the name, this is the annotation JSON *file*, not a directory.
ann_dir = os.path.join(BASE_PATH, 'data', 'coco_mini', 'instances_minitrain2017.json')

dataset = COCODataset(train_dir,
                      ann_dir,
                      args.targetHeight,
                      args.targetWidth,
                      args.numClass)
dataloader = DataLoader(dataset,
                        batch_size=args.batchSize,
                        shuffle=True,
                        collate_fn=collateFunction,
                        # pinned host memory only helps host->GPU copies; on CPU
                        # it just triggers a warning, so enable it for CUDA only
                        pin_memory=(device.type == 'cuda'),
                        num_workers=args.numWorkers)

In [None]:
# Build the DETR model and its training criterion (which returns a dict of
# metrics whose 'loss' entries are summed during training) from the config,
# and move both to the selected device.
model = DETR(args).to(device)
criterion = SetCriterion(args).to(device)

In [None]:
# Optionally warm-start from a checkpoint. `args.weightDir` is passed straight to
# torch.load, so it must be the path of a single state-dict file, not a directory.
# NOTE(review): torch.load unpickles the file, so only point weightDir at trusted
# checkpoints (recent PyTorch versions accept weights_only=True to restrict this).
if args.weightDir and os.path.exists(args.weightDir):
    print(f'loading pre-trained weights from {args.weightDir}')
    model.load_state_dict(torch.load(args.weightDir, map_location=device))

# multi-GPU training: replicate the model across visible GPUs when enabled
if args.multi:
    model = torch.nn.DataParallel(model)
        
# Two optimizer parameter groups (separate learning rates):
#  - parameters whose name does not contain "backbone": no per-group "lr", so they
#    use the optimizer's default learning rate (args.lr where AdamW is built);
#  - parameters whose name contains "backbone": use args.lrBackbone.
# Parameters with requires_grad=False are excluded from both groups.
paramDicts = [
        {"params": [p for n, p in model.named_parameters() if "backbone" not in n and p.requires_grad]},
        {
            "params": [p for n, p in model.named_parameters() if "backbone" in n and p.requires_grad],
            "lr": args.lrBackbone,
        },
    ]


In [None]:
# AdamW over the parameter groups built above; `lr` is the default for any group
# that does not set its own (the backbone group does).
optimizer = AdamW(paramDicts, lr=args.lr, weight_decay=args.weightDecay)
# StepLR decays the learning rate every `args.lrDrop` scheduler steps (default
# gamma); the training function steps it once per epoch.
lrScheduler = StepLR(optimizer, step_size=args.lrDrop)
# Best average training loss seen so far; +inf so any finite loss counts as an improvement.
prevBestLoss = float('inf')
# Not referenced by the training code in this notebook.
batches = len(dataloader)
logger = MetricsLogger()

In [None]:
# Per-batch training losses over the whole run (history across all epochs).
losses = []

# A single GradScaler for the whole run: it adapts its loss scale over many steps,
# so recreating it each epoch would throw that calibration away and cause
# avoidable skipped steps. With AMP off it is created disabled (a pass-through).
gradScaler = amp.GradScaler(enabled=bool(args.amp))


def train(epoch):
    """Run one training epoch over `dataloader`, then checkpoint on improvement.

    Uses the notebook-level `model`, `criterion`, `optimizer`, `lrScheduler`,
    `device` and `args`. Every batch loss is appended to the global `losses`
    history. The best-model check uses the mean loss of *this* epoch only. When
    that mean is lower than `prevBestLoss`, the model's state dict (unwrapped from
    `DataParallel` if needed) is saved to `<args.outputDir>/<args.taskName>.pt`
    and `prevBestLoss` is updated.

    Parameters
    ----------
    epoch : int
        Epoch index; used only for the progress-bar label.
    """
    # `prevBestLoss` is reassigned below. Without this declaration Python treats it
    # as a local and the first comparison raises UnboundLocalError.
    global prevBestLoss

    model.train()
    criterion.train()
    epochLosses = []

    with tqdm(dataloader, unit='batch', desc=f'Train epoch {epoch}') as tepoch:
        for x, y in tepoch:
            x = x.to(device)
            # presumably each target is a dict of tensors; move every entry to the device
            y = [{k: v.to(device) for k, v in t.items()} for t in y]

            if args.amp:
                with amp.autocast():
                    out = model(x)
                # presumably casts the mixed-precision outputs back to float32
                # before the loss is computed; see utils.misc.cast2Float
                out = cast2Float(out)
            else:
                out = model(x)

            metrics = criterion(out, y)
            # total loss = sum of every metric whose key contains 'loss'
            loss = sum(v for k, v in metrics.items() if 'loss' in k)

            # MARK: - backpropagation
            optimizer.zero_grad()
            if args.amp:
                gradScaler.scale(loss).backward()
                if args.clipMaxNorm > 0:
                    # unscale first so clipping sees the true gradient norm
                    gradScaler.unscale_(optimizer)
                    torch.nn.utils.clip_grad_norm_(model.parameters(), args.clipMaxNorm)
                gradScaler.step(optimizer)
                gradScaler.update()
            else:
                loss.backward()
                if args.clipMaxNorm > 0:
                    torch.nn.utils.clip_grad_norm_(model.parameters(), args.clipMaxNorm)
                optimizer.step()

            lossValue = loss.item()  # one host sync per batch, reused below
            epochLosses.append(lossValue)
            losses.append(lossValue)
            tepoch.set_postfix(loss=lossValue)

    lrScheduler.step()

    if not epochLosses:
        # empty dataloader: nothing to average or checkpoint
        return

    # Mean over this epoch only; a cumulative mean over all epochs could let a
    # worse epoch overwrite the best checkpoint.
    avgLoss = float(np.mean(epochLosses))

    if avgLoss < prevBestLoss:
        print('[+] Loss improved from {:.8f} to {:.8f}, saving model...'.format(prevBestLoss, avgLoss))
        os.makedirs(args.outputDir, exist_ok=True)
        # unwrap DataParallel (which exposes the wrapped model as `.module`)
        stateDict = getattr(model, 'module', model).state_dict()
        torch.save(stateDict, os.path.join(args.outputDir, f'{args.taskName}.pt'))
        prevBestLoss = avgLoss

In [None]:
# Run `args.epochs` training epochs. Each call to train() may save a checkpoint
# when its average loss beats `prevBestLoss`.
for epoch in trange(args.epochs):
    train(epoch)