In [1]:
import sys
sys.path.append('..')

import os
import torch
from tqdm import tqdm
from copy import deepcopy
from torch.utils.tensorboard import SummaryWriter

from src.utils.get_model_and_data import get_model_and_data
from src.utils.collate_fn_coco import collate_fn_coco

In [2]:
# Directory containing the checkpoints to evaluate. The same path is passed to
# SummaryWriter as the TensorBoard log directory in the evaluation cell.
# Alternative run, kept for reference:
# checkpoint_path = '../checkpoint/boxclip-finetune-coslr'
checkpoint_path = '../checkpoint/boxclip-one-cap/'
# Filename template; the `{}` placeholder is filled with the loop index
# (0, 5, ..., 100) via `.format(...)` in the evaluation cell.
checkpoint_name = 'checkpoint-epoch{}.pth.tar'

In [3]:
# Obtain the model and the datasets from the project helper, requesting the
# 'cuda' device. `datasets` is indexed with 'train' and 'val' in the evaluation cell.
# NOTE(review): 'cuda' is hard-coded, so this presumably requires a CUDA-capable
# machine — confirm how get_model_and_data uses the 'device' entry.
parameters = {'device': 'cuda'}
model, datasets = get_model_and_data(parameters)

loading annotations into memory...


KeyboardInterrupt: 

In [None]:
def eval_loss(dataloader, model, device):
    """Average every loss term reported by ``model`` over the batches of ``dataloader``.

    The model is switched to eval mode (and left in eval mode on return), and
    all forward/loss computation runs under ``torch.no_grad()``.

    Args:
        dataloader: Iterable yielding dict batches. It does not need to define
            ``__len__``; the number of batches is counted while iterating.
        model: Object providing ``eval()``, ``__call__(batch)`` returning a dict
            that is merged into the batch, and ``compute_loss(batch)`` returning
            a pair whose second element is a dict of per-term losses.
        device: Target handed to ``Tensor.to`` for every tensor value in a batch.

    Returns:
        dict: Maps ``'val_' + term`` to the mean of that term across batches.
        Each batch counts equally regardless of its size. Empty if the
        dataloader yields no batches.

    Raises:
        KeyError: If a later batch does not report a loss term that the first
            batch reported.
    """
    totals = {}
    n_batches = 0
    model.eval()
    with torch.no_grad():
        for batch in tqdm(dataloader):
            # Only values are reassigned, so iterating over the dict is safe.
            for key in batch:
                if torch.is_tensor(batch[key]):
                    batch[key] = batch[key].to(device)

            batch.update(model(batch))
            _, losses = model.compute_loss(batch)
            if n_batches == 0:
                # The set of loss terms is fixed by the first batch.
                totals = dict(losses)
            else:
                # Non-in-place add: never mutates values owned by the model.
                for key in totals:
                    totals[key] = totals[key] + losses[key]
            n_batches += 1

    # With zero batches `totals` is empty, so no division by zero can occur.
    return {'val_' + key: total / n_batches for key, total in totals.items()}

In [None]:
# Evaluate each saved checkpoint on the validation split and log the averaged
# losses to TensorBoard, using the training-iteration count as the x-axis.

# NOTE(review): 64 is presumably the batch size used during training — confirm.
TRAIN_BATCH_SIZE = 64
VAL_BATCH_SIZE = 20

# Ceiling division: a trailing partial batch still counts as one iteration.
iter_per_eps = (len(datasets['train']) + TRAIN_BATCH_SIZE - 1) // TRAIN_BATCH_SIZE
print(f'iter per epoch: {iter_per_eps}')

writer = SummaryWriter(log_dir=checkpoint_path)
val_loader = torch.utils.data.DataLoader(
    datasets['val'], batch_size=VAL_BATCH_SIZE, collate_fn=collate_fn_coco
)

for epoch in range(0, 105, 5):
    checkpoint_file = os.path.join(checkpoint_path, checkpoint_name.format(epoch))
    # Skip only epochs that have no saved file. Any other failure (corrupt
    # file, device mismatch, interrupt) is allowed to surface instead of being
    # misreported as "not found".
    if not os.path.isfile(checkpoint_file):
        print(f'{checkpoint_name.format(epoch)} not found.')
        continue

    checkpoint = torch.load(checkpoint_file)
    model.load_state_dict(checkpoint['model'])
    print(f'checkpoint {checkpoint_name.format(epoch)} loaded.')

    val_losses = eval_loss(val_loader, model, 'cuda')
    # x-axis position = number of training iterations completed at this epoch.
    writer.add_scalars("Loss/Iters", val_losses, epoch * iter_per_eps)

# Push buffered events to disk; the writer stays usable for later cells.
writer.flush()