In [None]:
%reset
import os
import numpy as np
import torch
import torch.optim as optim
import torch.nn as nn
from torch.utils.tensorboard import SummaryWriter
import random
import cv2
import models.layers
import addons.trees as trees
from models.vision import HTCNN, LeNet5

def loadData(data_path, data_file):
    """Read an image list file into ``[image_path, label]`` records.

    Each non-blank line of ``data_file`` is split on commas; the first field
    is joined onto ``data_path`` and the second field is parsed as an integer
    label. Any further fields on the line are ignored.

    Args:
        data_path: Directory prefix joined onto every file name in the list.
        data_file: Path of the comma-separated list file to read.

    Returns:
        A list of two-element lists ``[joined_path, int_label]`` in file order.

    Raises:
        IndexError: If a non-blank line has fewer than two fields.
        ValueError: If the second field is not an integer.
    """
    output = []
    with open(data_file, 'r') as f:
        for ln in f:
            # Blank lines (typically a trailing newline at the end of the
            # list) carry no record; skip them instead of failing on fields[1].
            if not ln.strip():
                continue
            fields = ln.rstrip('\r\n').split(',')
            output.append([os.path.join(data_path, fields[0]), int(fields[1])])
    return output
            
def loadInBatch(ds, r = 0, batchsize = 16, shuffle=False):
    """Assemble one batch of images and one-hot labels starting at record ``r``.

    Relies on the module-level names ``device``, ``lookup_lv_list``,
    ``coarst_dims``, ``n_fine`` and ``classTree``.

    Args:
        ds: List of ``[image_path, fine_label]`` records. It is shuffled in
            place when the read position wraps around and ``shuffle`` is set.
        r: Index of the first record to read.
        batchsize: Number of samples to place in the batch (at least 1).
        shuffle: Whether to shuffle ``ds`` when the end of the list is reached.

    Returns:
        A tuple ``(output_data, aux_labels, fine_labels, r, hasDone)``:
        the image tensor of shape (batchsize, C, H, W) taken from the first
        image, a list with one one-hot tensor per entry of ``lookup_lv_list``,
        the one-hot fine-label tensor of shape (batchsize, n_fine), the next
        read position, and a flag that is True when the read position wrapped
        past the end of ``ds`` while filling this batch.

    Raises:
        ValueError: If ``batchsize`` is smaller than 1.
        RuntimeError: If an image cannot be read by ``cv2.imread``.
    """
    if batchsize < 1:
        raise ValueError('batchsize must be at least 1, got %r' % (batchsize,))

    ndata = len(ds)
    hasDone = False
    output_data = None
    # Tensors created by torch.zeros do not track gradients, so no explicit
    # requires_grad handling is needed for the label or image buffers.
    aux_labels = [torch.zeros(batchsize, coarst_dims[j], device=device)
                  for j, _ in enumerate(lookup_lv_list)]
    fine_labels = torch.zeros(batchsize, n_fine, device=device)

    for i in range(batchsize):
        img_path = ds[r][0]
        base_label = ds[r][1]
        img_data = cv2.imread(img_path)
        if img_data is None:
            # cv2.imread signals failure by returning None rather than raising.
            raise RuntimeError('Failed to read image %s' % img_path)
        if output_data is None:
            # The image buffer is sized from the first image: (H, W, C) -> (C, H, W).
            output_data = torch.zeros(batchsize, img_data.shape[2], img_data.shape[0],
                                      img_data.shape[1], device=device)
        output_data[i, ...] = torch.tensor(img_data).float().permute(2, 0, 1)

        for j, lv in enumerate(lookup_lv_list):
            up_cls = lookupParent(classTree, base_label, lv)
            aux_labels[j][i, up_cls] = 1.0
        fine_labels[i, base_label] = 1.0

        r += 1
        if r >= ndata:
            # Wrap around so the batch is always filled completely.
            r = 0
            hasDone = True
            if shuffle:
                random.shuffle(ds)

    return output_data, aux_labels, fine_labels, r, hasDone
        
def lookupParent(tree, fine_node, upper_lv=1):
    """Return the entry stored for ``fine_node`` at level ``upper_lv``.

    ``tree`` maps a fine node to an indexable sequence; level ``upper_lv``
    corresponds to position ``upper_lv - 1`` of that sequence.
    """
    ancestors = tree[fine_node]
    level_index = upper_lv - 1
    return ancestors[level_index]

def accumulateList(list1, list2):
    """Return the element-wise mean of ``list1`` and ``list2``.

    The result has one entry per element of ``list1``; ``list2`` must be at
    least as long, and any extra trailing elements of ``list2`` are ignored.
    """
    return [(left + list2[pos]) * 0.5 for pos, left in enumerate(list1)]

def computeBatchAccuracy(pred, expected):
    """Compute the top-1 accuracy of every output head for one batch.

    ``pred`` and ``expected`` are parallel sequences with one array-like per
    head, indexed by sample along the first axis. A sample counts as correct
    when the argmax of its prediction equals the argmax of its expected row.
    The batch size is taken from the first head's prediction.

    Returns:
        A list with one accuracy (fraction in [0, 1]) per head of ``pred``.
    """
    batch_total = pred[0].shape[0]
    accuracies = []
    for head, head_pred in enumerate(pred):
        hits = sum(
            (1.0 for sample in range(batch_total)
             if head_pred[sample].argmax() == expected[head][sample, ...].argmax()),
            0.0,
        )
        accuracies.append(hits / batch_total)
    return accuracies

def computeAccuracy(dataset, model, batchsize = 1, withAux = False):
    """Evaluate ``model`` over the whole ``dataset`` and return its accuracy.

    The dataset is read in order through ``loadInBatch``; the last batch is
    shrunk to the number of remaining records so every record is scored
    exactly once. Per-batch accuracies from ``computeBatchAccuracy`` are
    weighted by the batch size, so the result is the fraction of correctly
    classified records regardless of how the dataset divides into batches.

    Args:
        dataset: List of records accepted by ``loadInBatch``.
        model: Callable returning ``(pred_final, pred_aux)`` for a batch.
        batchsize: Maximum number of records per forward pass (at least 1).
        withAux: Whether to also score the auxiliary heads in ``pred_aux``
            against the coarse labels followed by the fine labels.

    Returns:
        A tuple ``(output, aux_output)``. ``output`` is a one-element list
        with the accuracy of the final head; ``aux_output`` holds one accuracy
        per auxiliary head, or is empty when ``withAux`` is False. Both lists
        are empty for an empty dataset.

    Raises:
        ValueError: If ``batchsize`` is smaller than 1 for a non-empty dataset.
    """
    data_count = len(dataset)
    if data_count == 0:
        return [], []
    if batchsize < 1:
        raise ValueError('batchsize must be at least 1, got %r' % (batchsize,))

    correct = []       # running count of correct samples for the final head
    aux_correct = []   # running counts of correct samples per auxiliary head
    ptr = 0
    remaining = data_count
    while remaining > 0:
        # Shrink the last batch so no record is read twice via wrap-around.
        cur_batch = min(batchsize, remaining)
        batch_data, expected_aux, expected_fine, ptr, _ = loadInBatch(dataset, ptr, cur_batch)
        pred_final, pred_aux = model(batch_data)

        batch_result = computeBatchAccuracy([pred_final], [expected_fine])
        if not correct:
            correct = [0.0] * len(batch_result)
        for j, acc in enumerate(batch_result):
            # Convert the batch fraction back into a sample count.
            correct[j] += acc * cur_batch

        if withAux:
            batch_aux_result = computeBatchAccuracy(pred_aux, expected_aux + [expected_fine])
            if not aux_correct:
                aux_correct = [0.0] * len(batch_aux_result)
            for j, acc in enumerate(batch_aux_result):
                aux_correct[j] += acc * cur_batch

        remaining -= cur_batch

    output = [count / data_count for count in correct]
    aux_output = [count / data_count for count in aux_correct]
    return output, aux_output

def train(trainset, valset, label_file, output_path, output_fname,
          start_lr=0.1, lr_discount=0.1, lr_steps=None, epoch=30,
          train_batch = 16, val_batch = 16, val_at = 10,
          checkpoint = None, jud_at = -1):
    """Train an HTCNN model and save the best-validating weights.

    Relies on the module-level names ``backbone``, ``lookup_lv_list`` and
    ``writer``. The model is validated once before training to establish the
    accuracy a later epoch must beat before its weights are saved.

    Args:
        trainset: Training records accepted by ``loadInBatch`` (shuffled in
            place at the end of every epoch).
        valset: Validation records accepted by ``loadInBatch``.
        label_file: Label tree file passed to ``HTCNN``.
        output_path: Directory in which the best model is saved.
        output_fname: File name of the saved model inside ``output_path``.
        start_lr: Initial SGD learning rate.
        lr_discount: Factor the learning rate is multiplied by at each step.
        lr_steps: Epoch indices after which the learning rate is discounted;
            ``None`` means no discounting.
        epoch: Number of training epochs.
        train_batch: Training batch size.
        val_batch: Validation batch size.
        val_at: Validate every ``val_at`` epochs (epoch 0 included).
        checkpoint: Optional path of a state dict to resume from; ignored if
            the file does not exist.
        jud_at: Unused; kept for interface compatibility.
    """
    if lr_steps is None:
        lr_steps = []

    model = HTCNN(label_file, with_aux = True, with_fc = True, backbone=backbone,
              isCuda=True).cuda()

    output_filepath = os.path.join(output_path, output_fname)

    if checkpoint is not None and os.path.isfile(checkpoint):
        model.load_state_dict(torch.load(checkpoint))
        print('Loaded from checkpoint %s'%checkpoint)

    # Baseline validation: only a model that beats this accuracy is saved.
    backbone.eval()
    model.eval()
    with torch.no_grad():
        val_result, aux_val_result = computeAccuracy(valset, model, val_batch, withAux=True)
        v_result = val_result[0]
        print('Validation Accuracy: %f'%v_result)
        best_v_result = v_result

    lr = start_lr
    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9)

    # One loss per coarse level, plus one for the auxiliary fine head (last),
    # plus a separate loss for the final output.
    final_loss = nn.SoftMarginLoss()
    losses = [nn.SoftMarginLoss() for _ in lookup_lv_list]
    aux_loss_names = ['Coarst %d loss'%lv for lv in lookup_lv_list]
    aux_val_names = ['Level %d accuracy'%lv for lv in lookup_lv_list]
    losses.append(nn.SoftMarginLoss())
    aux_loss_names.append('Fine loss')
    aux_val_names.append('Fine accuracy')
    n_aux = len(losses) - 1
    aux_accuracy = {}

    for i in range(epoch):
        # training phase
        backbone.train()
        model.train()
        ptr = 0
        hasFinishEpoch = False
        # Slot k < n_aux accumulates coarse level k; slot n_aux the fine head.
        epoch_aux_losses_v = [0.0] * (n_aux + 1)
        epoch_loss_v = 0.0
        iter_c = 0
        while not hasFinishEpoch:
            optimizer.zero_grad()

            batch_input, gt_aux, gt_final, ptr, hasFinishEpoch = loadInBatch(trainset, ptr, train_batch, shuffle=True)
            pred_final, pred_aux = model(batch_input)

            # SoftMarginLoss expects targets in {-1, +1}; loadInBatch yields
            # {0, 1} one-hot labels. A zero target contributes a constant
            # log(2) with no gradient, so map the labels to signed targets.
            signed_final = gt_final * 2.0 - 1.0

            total_loss = final_loss(pred_final, signed_final)
            for i_aux in range(n_aux):
                signed_aux = gt_aux[i_aux] * 2.0 - 1.0
                aux_loss = losses[i_aux](pred_aux[i_aux], signed_aux)
                total_loss = total_loss + aux_loss
                epoch_aux_losses_v[i_aux] += aux_loss.item()
            fine_loss = losses[-1](pred_aux[-1], signed_final)
            total_loss = total_loss + fine_loss
            epoch_aux_losses_v[n_aux] += fine_loss.item()

            # compute gradients
            total_loss.backward()

            # update weights
            optimizer.step()

            # Accumulate the Python float only, so no autograd graph is
            # retained between iterations.
            epoch_loss_v += total_loss.item()
            iter_c += 1

        # Report the mean of every loss over the epoch's iterations.
        plot_loss = {}
        for iloss in range(n_aux+1):
            epoch_aux_losses_v[iloss] /= iter_c
            plot_loss[aux_loss_names[iloss]] = epoch_aux_losses_v[iloss]
        epoch_loss_v /= iter_c
        plot_loss['total loss'] = epoch_loss_v
        writer.add_scalars('training loss',
                          plot_loss,
                          i)
        print(plot_loss)

        # validation phase
        if i % val_at == 0:
            print('Validating...')
            backbone.eval()
            model.eval()
            with torch.no_grad():
                val_result, aux_val_result = computeAccuracy(valset, model, val_batch, withAux=True)
                for iacc in range(len(aux_val_names)):
                    aux_accuracy[aux_val_names[iacc]] = aux_val_result[iacc]
                v_result = val_result[0]
                print('Validation Accuracy: %f'%v_result)
                print(aux_accuracy)
                if v_result > best_v_result:
                    print('Best model found and saving it.')
                    torch.save(model.state_dict(), output_filepath)
                    best_v_result = v_result
        if i in lr_steps:
            olr = lr
            lr *= lr_discount
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr
            print('learning rate has been discounted from %f to %f'%(olr, lr))
        # Logged every epoch; between validations this repeats the most
        # recent validation accuracies.
        writer.add_scalars('Accuracy',
                          aux_accuracy,
                          i)

    print('Model has been trained.')
    
def main():
    """Load the train/validation lists, run training, then release resources.

    Relies on the module-level names ``model_path``, ``model_fname``,
    ``ds_root_path``, ``training_file``, ``val_file``, ``label_filepath``,
    ``backbone`` and ``writer``. Training resumes from the saved model file
    when one already exists at the output location.
    """
    # Declared global so that clearing it below drops the module-level
    # reference instead of binding an unrelated local name.
    global backbone

    checkpoint_path = os.path.join(model_path, model_fname)

    try:
        train_set = loadData(ds_root_path, training_file)
        val_set = loadData(ds_root_path, val_file)
        print('Training set has been buffered.')
        train(train_set, val_set, label_filepath,
              output_path = model_path, output_fname = model_fname,
              epoch=300, val_at=5, lr_steps=[100, 200],
              train_batch=128, val_batch=64, checkpoint=checkpoint_path)
    finally:
        # Release the backbone and flush the log writer even if training fails.
        backbone = None
        torch.cuda.empty_cache()
        writer.close()
    print('Done')

if __name__ == '__main__':
    # Script entry point: defines the module-level configuration that the
    # functions above read as globals, then starts training.

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # Dataset locations.
    ds_root_path = '/datasets/vision/cifar100_clean'
    label_filepath = os.path.join(ds_root_path, 'tree.txt')
    training_file = os.path.join(ds_root_path, 'train.txt')
    val_file = os.path.join(ds_root_path, 'val.txt')
    # NOTE(review): the test list currently points at the validation list.
    test_file = os.path.join(ds_root_path, 'val.txt')

    # Class hierarchy: one lookup level per coarse level, numbered from 1.
    classTree, n_coarst, coarst_dims = trees.build_itree(label_filepath)
    lookup_lv_list = [i+1 for i in range(n_coarst)]
    n_fine = len(list(classTree.keys()))

    # Output location; create missing parent directories and tolerate an
    # already existing directory.
    model_path = '/models/cifar100_htcnn_1'
    os.makedirs(model_path, exist_ok=True)
    model_fname = 'model.pth'

    backbone = LeNet5(n_classes=n_fine).cuda()

    writer = SummaryWriter(log_dir = '../training', purge_step = 0,
                          flush_secs = 5)

    main()

Once deleted, variables cannot be recovered. Proceed (y/[n])? y
Training set has been buffered.
Loaded from checkpoint /models/cifar100_htcnn_1/model.pth
Validation Accuracy: 0.015924
{'Coarst 1 loss': 0.6918770869064819, 'Fine loss': 0.6930978415567247, 'total loss': 2.0781195133238497}
Validating...
Validation Accuracy: 0.018710
{'Level 1 accuracy': 0.07056130573248408, 'Fine accuracy': 0.011345541401273885}
Best model found and saving it.
{'Coarst 1 loss': 0.6918485879593188, 'Fine loss': 0.6930979802785322, 'total loss': 2.0780911018781345}
{'Coarst 1 loss': 0.6918256976415434, 'Fine loss': 0.6930981472020259, 'total loss': 2.078068351501699}
{'Coarst 1 loss': 0.6918037370647616, 'Fine loss': 0.6930982985764819, 'total loss': 2.0780464919936628}
{'Coarst 1 loss': 0.6917813866949447, 'Fine loss': 0.6930983632116976, 'total loss': 2.078024142233612}
{'Coarst 1 loss': 0.6917600182011304, 'Fine loss': 0.6930984560485995, 'total loss': 2.0780028333444425}
Validating...
Validation Accura

{'Coarst 1 loss': 0.6912481540914082, 'Fine loss': 0.6930984513229116, 'total loss': 2.0774895730225937}
{'Coarst 1 loss': 0.6912468475149111, 'Fine loss': 0.693098458944989, 'total loss': 2.0774882614155255}
{'Coarst 1 loss': 0.6912390647641838, 'Fine loss': 0.6930984406520033, 'total loss': 2.077480446042307}
{'Coarst 1 loss': 0.6912300938840412, 'Fine loss': 0.6930984286091212, 'total loss': 2.077471453210582}
{'Coarst 1 loss': 0.691237036681846, 'Fine loss': 0.6930984119929926, 'total loss': 2.0774783886911923}
Validating...
Validation Accuracy: 0.021895
{'Level 1 accuracy': 0.1252985668789809, 'Fine accuracy': 0.009952229299363057}
Best model found and saving it.
{'Coarst 1 loss': 0.6912214606619247, 'Fine loss': 0.6930983857730465, 'total loss': 2.077462757944756}
{'Coarst 1 loss': 0.6912182101508235, 'Fine loss': 0.6930984205297192, 'total loss': 2.0774595286230295}
{'Coarst 1 loss': 0.6912119842856131, 'Fine loss': 0.6930984609267291, 'total loss': 2.0774533309595054}
{'Coarst 