In [None]:
import argparse
import logging
import os
import random
import sys
import time
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from tensorboardX import SummaryWriter
from torch.nn.modules.loss import CrossEntropyLoss
from torch.utils.data import DataLoader
from tqdm import tqdm
from utils import DiceLoss
from torchvision import transforms
import gc


In [None]:
from datasets.dataset_HuBMAP import HuBMAP_dataset, RandomGenerator
# --- Logging setup -----------------------------------------------------------
# INFO-and-above records are written to <snapshot_path>/log.txt with a
# millisecond-resolution timestamp in front of each message.
logging.basicConfig(filename=snapshot_path + "/log.txt", level=logging.INFO,
                    format='[%(asctime)s.%(msecs)03d] %(message)s', datefmt='%H:%M:%S')

# Mirror log records to the notebook output. Re-running this cell must not
# stack a second stdout handler on the root logger, otherwise every message
# would be printed twice (three times after the next re-run, and so on).
root_logger = logging.getLogger()
has_stdout_handler = any(
    isinstance(handler, logging.StreamHandler) and getattr(handler, "stream", None) is sys.stdout
    for handler in root_logger.handlers
)
if not has_stdout_handler:
    root_logger.addHandler(logging.StreamHandler(sys.stdout))
logging.info(str(args))

# --- Run configuration read from the parsed arguments ------------------------
base_lr = args.base_lr
num_classes = args.num_classes
# Effective batch size: args.batch_size multiplied by the number of GPUs.
batch_size = args.batch_size * args.n_gpu

# --- Test dataset ------------------------------------------------------------
# NOTE(review): the transform is named RandomGenerator; if it applies random
# augmentation (not visible from this notebook), the test metrics will not be
# deterministic -- confirm against datasets.dataset_HuBMAP.
db_test = HuBMAP_dataset(base_dir=args.root_path, list_dir=args.list_dir, split="test",
                         transform=transforms.Compose(
                             [RandomGenerator(output_size=[args.img_size, args.img_size])]))

# This is the "test" split, so report it as such.
print("The length of test set is: {}".format(len(db_test)))

def worker_init_fn(worker_id):
    """Seed the random number generators of one DataLoader worker process.

    Args:
        worker_id: Index of the worker, passed in by the DataLoader.

    Each worker gets the seed ``args.seed + worker_id`` so that workers draw
    different, but reproducible, random streams.
    """
    worker_seed = args.seed + worker_id
    random.seed(worker_seed)
    # Seed NumPy's global RNG as well; otherwise any NumPy-based randomness in
    # the dataset/transform is left unseeded inside the worker.
    # np.random.seed only accepts values in [0, 2**32 - 1].
    np.random.seed(worker_seed % (2 ** 32))

# Evaluation loader. Samples are served in dataset order (shuffle=False):
# the reported metrics are means over batches, so shuffling would change which
# samples fall into the final (possibly smaller) batch and make the averages
# vary from run to run.
testloader = DataLoader(db_test, batch_size=batch_size, shuffle=False, num_workers=8, pin_memory=True,
                        worker_init_fn=worker_init_fn)

 

In [None]:
# Wrap the model in DataParallel only when more than one GPU is configured.
model = nn.DataParallel(model) if args.n_gpu > 1 else model

# The two loss terms used by the evaluation loop.
ce_loss = nn.CrossEntropyLoss()
dice_loss = DiceLoss(num_classes)

# TensorBoard event files go to <snapshot_path>/log.
writer = SummaryWriter("{}/log".format(snapshot_path))

# Record how many batches the loader yields, alongside max_iterations.
batches_per_epoch = len(testloader)
logging.info("{} iterations per epoch. {} max iterations ".format(batches_per_epoch, max_iterations))

# Progress bar over the epoch range, 70 columns wide.
iterator = tqdm(range(max_epoch), ncols=70)


In [None]:
# Evaluate the model on the test loader and report the mean per-batch losses:
#   * combined loss = 0.5 * cross-entropy + 0.5 * Dice
#   * Dice loss on its own
total_test_loss = 0
total_test_dice_loss = 0
iter_num = 0

# Switch the model to evaluation mode once, before the loop.
model.eval()

# No backward pass or optimizer step happens during evaluation, so run the
# forward passes without autograd tracking (saves memory and time).
with torch.no_grad():
    for sampled_batch in testloader:
        image_batch, label_batch = sampled_batch['image'], sampled_batch['label']
        image_batch, label_batch = image_batch.cuda(), label_batch.cuda()

        outputs = model(image_batch)
        loss_ce = ce_loss(outputs, label_batch.long())
        loss_dice = dice_loss(outputs, label_batch, softmax=True)
        loss = 0.5 * loss_ce + 0.5 * loss_dice

        # .item() converts the scalar loss tensors to plain Python floats.
        total_test_loss += loss.item()
        total_test_dice_loss += loss_dice.item()

        iter_num = iter_num + 1

# An empty loader would otherwise surface as a bare ZeroDivisionError.
if iter_num == 0:
    raise RuntimeError("testloader produced no batches; cannot compute average test losses")

# Averages are taken over the number of batches, not the number of samples.
avg_test_loss = total_test_loss / iter_num
avg_test_loss_dice = total_test_dice_loss / iter_num

writer.add_scalar('info/avg_test_loss', avg_test_loss)
writer.add_scalar('info/avg_test_loss_dice', avg_test_loss_dice)

logging.info('test_loss : %f, test_loss_dice: %f' % (avg_test_loss, avg_test_loss_dice))



<a style='text-decoration:none;line-height:16px;display:flex;color:#5B5B62;padding:10px;justify-content:end;' href='https://deepnote.com?utm_source=created-in-deepnote-cell&projectId=c5aa932a-5c34-48e2-be33-10614017587c' target="_blank">
Created in <span style='font-weight:600;margin-left:4px;'>Deepnote</span></a>