## classifier_testing

In [37]:
from networks import *
from options import *

import torch.utils.data
import os.path
import torch
import torchvision.transforms as transforms
import torch.utils.data as data
import os
from pathlib import Path
from functools import reduce
import operator
import random
import cv2
import numpy as np
import argparse

class BaseOptions():
    """Command-line options shared by every experiment.

    Subclasses add arguments by overriding ``initialize`` (calling this one
    first) and must set ``self.isTrain``, which ``parse`` copies onto the
    parsed namespace.
    """

    def __init__(self):
        self.parser = argparse.ArgumentParser()
        self.initialized = False

    def initialize(self):
        """Register the shared arguments on ``self.parser``."""
        # experiment specifics
        self.parser.add_argument('--name', type=str, default='label2city', help='name of the experiment. It decides where to store samples and models')
        self.parser.add_argument('--gpu_ids', type=str, default='0', help='gpu ids: e.g. 0  0,1,2, 0,2. use -1 for CPU')
        self.parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here')
        self.parser.add_argument('--norm', type=str, default='instance', help='instance normalization or batch normalization')
        self.parser.add_argument('--use_dropout', action='store_true', help='use dropout for the generator')

        # input/output sizes
        self.parser.add_argument('--batchSize', type=int, default=1, help='input batch size')
        self.parser.add_argument('--loadSize', type=int, default=1024, help='scale images to this size')
        self.parser.add_argument('--fineSize', type=int, default=512, help='then crop to this size')
        self.parser.add_argument('--label_nc', type=int, default=35, help='# of input image channels')
        self.parser.add_argument('--output_nc', type=int, default=3, help='# of output image channels')
        self.parser.add_argument('--use_PIL', action='store_true', help='if true, uses the PIL library for data loading, openCV by default')
        self.parser.add_argument('--input_mask_fill', '--fill', dest='fill', type=str, default='W&B',
                                 help='fill to use in the input images for the artificially created masks, '
                                      'options: "W&B" (salt&pepper - default) | "W" (white) | "B" (black) | "G" (grey)')

        # for setting inputs
        self.parser.add_argument('--dataroot', type=str, default='/blanca/training_datasets/pix2pix/images_target_clean_classified')
        # With nargs='*' the parsed value is a list, so the default must be a
        # list too. A plain string default is passed through unchanged, and
        # indexing it (dataset_list[0]) would return single characters.
        self.parser.add_argument('--dataset_list', action='store', type=str, nargs='*',
                                 default=['images_target_clean_classified', 'video_target_clean_classified'],
                                 help='dataset folders, relative to --dataroot')
        self.parser.add_argument('--resize_or_crop', type=str, default='scale_width', help='scaling and cropping of images at load time [resize_and_crop|crop|scale_width|scale_width_and_crop]')
        self.parser.add_argument('--serial_batches', action='store_true', help='if true, takes images in order to make batches, otherwise takes them randomly')
        self.parser.add_argument('--no_flip', action='store_true', help='if specified, do not flip the images for data argumentation')
        self.parser.add_argument('--nThreads', default=2, type=int, help='# threads for loading data')
        self.parser.add_argument('--max_dataset_size', type=int, default=float("inf"), help='Maximum number of samples allowed per dataset. If the dataset directory contains more than max_dataset_size, only a subset is loaded.')

        # for displays
        self.parser.add_argument('--display_winsize', type=int, default=512,  help='display window size')
        self.parser.add_argument('--tf_log', action='store_true', help='if specified, use tensorboard logging. Requires tensorflow installed')

        ## visdom
        self.parser.add_argument('--display_id', type=int, default=1, help='window id of the web display')
        self.parser.add_argument('--display_port', type=int, default=8097, help='visdom port of the web display')

        # for generator
        self.parser.add_argument('--netG', type=str, default='global', help='selects model to use for netG')
        self.parser.add_argument('--ngf', type=int, default=64, help='# of gen filters in first conv layer')
        self.parser.add_argument('--n_downsample_global', type=int, default=4, help='number of downsampling layers in netG')
        self.parser.add_argument('--n_blocks_global', type=int, default=9, help='number of residual blocks in the global generator network')
        self.parser.add_argument('--n_blocks_local', type=int, default=3, help='number of residual blocks in the local enhancer network')
        self.parser.add_argument('--n_local_enhancers', type=int, default=1, help='number of local enhancers to use')
        self.parser.add_argument('--niter_fix_global', type=int, default=0, help='number of epochs that we only train the outmost local enhancer')

        # for instance-wise features
        self.parser.add_argument('--no_instance', action='store_true', help='if specified, do *not* add instance map as input')
        self.parser.add_argument('--instance_feat', action='store_true', help='if specified, add encoded instance features as input')
        self.parser.add_argument('--label_feat', action='store_true', help='if specified, add encoded label features as input')
        self.parser.add_argument('--feat_num', type=int, default=3, help='vector length for encoded features')
        self.parser.add_argument('--load_features', action='store_true', help='if specified, load precomputed feature maps')
        self.parser.add_argument('--n_downsample_E', type=int, default=4, help='# of downsampling layers in encoder')
        self.parser.add_argument('--nef', type=int, default=16, help='# of encoder filters in the first conv layer')
        self.parser.add_argument('--n_clusters', type=int, default=10, help='number of clusters for features')

        self.initialized = True

    def parse(self, save=True):
        """Parse ``sys.argv`` into ``self.opt``, select the GPU, print and optionally save the options.

        Args:
            save: if True, write ``opt.txt`` into ``<checkpoints_dir>/<name>``,
                unless ``opt.continue_train`` is set.

        Returns:
            The parsed ``argparse.Namespace``. ``gpu_ids`` is converted to a
            list of non-negative ints, and ``isTrain`` is copied from ``self``.
        """
        if not self.initialized:
            self.initialize()
        self.opt = self.parser.parse_args()
        self.opt.isTrain = self.isTrain   # train or test

        # Negative ids (e.g. '-1') mean "CPU" and are dropped from the list.
        parsed_ids = [int(s) for s in self.opt.gpu_ids.split(',')]
        self.opt.gpu_ids = [gpu_id for gpu_id in parsed_ids if gpu_id >= 0]

        # set gpu ids
        if len(self.opt.gpu_ids) > 0:
            torch.cuda.set_device(self.opt.gpu_ids[0])

        args = vars(self.opt)

        print('------------ Options -------------')
        for k, v in sorted(args.items()):
            print('%s: %s' % (str(k), str(v)))
        print('-------------- End ----------------')

        # save to the disk
        expr_dir = os.path.join(self.opt.checkpoints_dir, self.opt.name)
        os.makedirs(expr_dir, exist_ok=True)
        # --continue_train is only registered by TrainOptions.
        if save and not getattr(self.opt, 'continue_train', False):
            file_name = os.path.join(expr_dir, 'opt.txt')
            with open(file_name, 'wt') as opt_file:
                opt_file.write('------------ Options -------------\n')
                for k, v in sorted(args.items()):
                    opt_file.write('%s: %s\n' % (str(k), str(v)))
                opt_file.write('-------------- End ----------------\n')
        return self.opt

class TrainOptions(BaseOptions):
    """BaseOptions plus the training-only arguments; sets ``isTrain = True``."""

    def initialize(self):
        BaseOptions.initialize(self)
        # for displays
        self.parser.add_argument('--display_freq', type=int, default=100, help='frequency of showing training results on screen')
        self.parser.add_argument('--print_freq', type=int, default=100, help='frequency of showing training results on console')
        self.parser.add_argument('--save_latest_freq', type=int, default=10000, help='frequency of saving the latest results')
        self.parser.add_argument('--save_epoch_freq', type=int, default=10, help='frequency of saving checkpoints at the end of epochs')
        self.parser.add_argument('--no_html', action='store_true', help='do not save intermediate training results to [opt.checkpoints_dir]/[opt.name]/web/')
        self.parser.add_argument('--debug', action='store_true', help='only do one epoch and displays at each iteration')
        ## visdom:
        self.parser.add_argument('--display_single_pane_ncols', type=int, default=0, help='if positive, display all images in a single visdom web panel with certain number of images per row.')

        # for training
        self.parser.add_argument('--continue_train', action='store_true', help='continue training: load the latest model')
        self.parser.add_argument('--load_pretrain', type=str, default='', help='load the pretrained model from the specified location')
        self.parser.add_argument('--which_epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model')
        self.parser.add_argument('--phase', type=str, default='train', help='train, val, test, etc')
        self.parser.add_argument('--niter', type=int, default=50, help='# of iter at starting learning rate')
        self.parser.add_argument('--niter_decay', type=int, default=50, help='# of iter to linearly decay learning rate to zero')
        self.parser.add_argument('--beta1', type=float, default=0.5, help='momentum term of adam')
        self.parser.add_argument('--lr', type=float, default=0.0002, help='initial learning rate for adam')

        # for discriminators
        self.parser.add_argument('--num_D', type=int, default=2, help='number of discriminators to use')
        self.parser.add_argument('--n_layers_D', type=int, default=3, help='only used if which_model_netD==n_layers')
        self.parser.add_argument('--ndf', type=int, default=64, help='# of discrim filters in first conv layer')
        self.parser.add_argument('--lambda_feat', type=float, default=10.0, help='weight for feature matching loss')
        self.parser.add_argument('--no_ganFeat_loss', action='store_true', help='if specified, do *not* use discriminator feature matching loss')
        self.parser.add_argument('--no_vgg_loss', action='store_true', help='if specified, do *not* use VGG feature matching loss')
        self.parser.add_argument('--no_lsgan', action='store_true', help='do *not* use least square GAN, if false, use vanilla GAN')
        self.parser.add_argument('--pool_size', type=int, default=0, help='the size of image buffer that stores previously generated images')

        # for the generator:
        self.parser.add_argument('--lambda_L1', type=float, default=0, help='weight for the L1 loss')
        self.parser.add_argument('--lambda_G_cos1', type=float, default=0, help='weight for the G_cos1 loss')
        self.parser.add_argument('--lambda_G_cos2', type=float, default=0, help='weight for the G_cos2 loss')
        self.parser.add_argument('--lambda_G_cos1_z', type=float, default=0, help='weight for the G_cos1_z loss')
        self.parser.add_argument('--lambda_G_cos2_z', type=float, default=0, help='weight for the G_cos2_z loss')
        self.parser.add_argument('--lambda_G_KL_fake', type=float, default=0, help='weight for the KL loss')
        self.parser.add_argument('--lambda_E_KL_real', type=float, default=0, help='weight for the KL_real loss')
        self.parser.add_argument('--lambda_E_KL_fake', type=float, default=0, help='weight for the KL_fake loss')

        # This is a numeric weight. type=bool would turn any non-empty string,
        # even "0", into True.
        self.parser.add_argument('--lambda_G_class', type=float, default=10, help='weight for the classification loss')

        # for the classifier
        # NOTE(review): these help texts were inferred from the option names and
        # from set_classifier_model (class_nc is passed as nchannels); confirm.
        self.parser.add_argument('--class_nc', type=int, default=4, help='# of input channels of the classifier')
        self.parser.add_argument('--n_layers_C', type=int, default=3, help='# of layers in the classifier')
        self.parser.add_argument('--num_C', type=int, default=1, help='number of classifiers to use')

        self.isTrain = True

def read_image_OpenCV(path, opt, is_pair=False, target_size=None):
    """Read an image with OpenCV, keeping every channel (e.g. alpha), and resize it.

    Args:
        path: image file path (str).
        opt: options object; only ``opt.loadSize`` is read.
        is_pair: if True, resize to ``2 * loadSize`` wide by ``loadSize`` high.
        target_size: explicit ``(width, height)``, used when ``is_pair`` is False.
            If it is falsy, the image is resized to ``loadSize`` x ``loadSize``.

    Returns:
        The resized image array.

    Raises:
        IOError: if OpenCV cannot read ``path``.
    """
    im = cv2.imread(path, cv2.IMREAD_UNCHANGED)
    if im is None:
        # cv2.imread reports failure by returning None rather than raising.
        # Without this check, cv2.resize would fail with an unrelated error.
        raise IOError('Could not read image: %s' % path)

    if is_pair:
        size = (opt.loadSize * 2, opt.loadSize)
    elif target_size:
        size = target_size
    else:
        size = (opt.loadSize, opt.loadSize)
    # cv2.resize takes dsize as (width, height).
    return cv2.resize(im, size)

def fill_gaps(im, opt,
              fill_input_with=None,
              add_artificial=False,
              only_artificial=False,
              cm_p = '/blanca/resources/contour_mask/contour_mask.png',
              occlusions_dir='/blanca/resources/db_occlusions'):
    """Fill the invisible pixels of an alpha image, optionally after adding a random occlusion.

    ``im`` is modified in place and also returned. Channel 3 is used as alpha
    and channels ``0 .. im.shape[2] - 2`` are filled. This assumes exactly 4
    channels (presumably BGRA as read by ``read_image_OpenCV`` -- confirm).

    Args:
        im: image array of shape (H, W, 4). H and W must match the contour mask
            as resized by ``read_image_OpenCV`` (``opt.loadSize`` square).
        opt: options. Reads ``opt.fill`` (the default fill) and ``opt.phase``
            (no occlusion is added when it is ``'test'``), plus
            ``opt.loadSize`` via ``read_image_OpenCV``.
        fill_input_with: one of ``'W'`` (255), ``'B'`` (0), ``'G'`` (127),
            ``'W&B'`` (uniform noise from ``cv2.randu`` in [0, 255)), or
            ``'average'`` (per-channel mean of the visible pixels).
            A falsy value falls back to ``opt.fill``.
        add_artificial: also multiply alpha by a randomly chosen mask file from
            ``occlusions_dir`` and fill the newly hidden pixels.
        only_artificial: skip filling the image's original gaps.
        cm_p: contour template mask. Pixels where it is 0 are forced visible.
        occlusions_dir: folder whose files are the candidate occlusion masks.

    Returns:
        ``im`` after filling.

    Raises:
        KeyError: for an unknown ``fill_input_with``.
        IOError: if ``add_artificial`` applies and ``occlusions_dir`` is empty.
    """
    if not fill_input_with: fill_input_with = opt.fill

    hw = im.shape[:2]
    # Candidate fills for the gaps, one full-size plane each.
    which_fill = {
        'W': np.ones(hw) * 255,
        'B': np.zeros(hw),
        'W&B': cv2.randu(np.zeros(hw), 0, 255),
        'G': np.ones(hw) * 127,
    }

    def fill_holes(holes):
        # Overwrite every non-alpha channel wherever `holes` is True. For
        # 'average', the mean is taken over pixels that are visible right now.
        visible = im[:, :, 3] != 0
        for c in range(im.shape[2] - 1):
            if fill_input_with == 'average':
                fill = np.ones(hw) * np.mean(im[:, :, c][visible])
            else:
                fill = which_fill[fill_input_with]
            im[:, :, c][holes] = fill[holes]

    # Pixels where the contour template is 0 lie outside the texture map.
    # Mark them visible so they are never filled.
    im_cm = read_image_OpenCV(cm_p, opt, is_pair=False)
    new_alpha = im[:, :, 3] != 0
    new_alpha[im_cm == 0] = 1
    im[:, :, 3] = new_alpha * 255

    if not only_artificial:
        fill_holes(~new_alpha)

    if add_artificial and opt.phase != 'test':
        ## SELECT AN OCCLUSION AND APPLY TO THE IMAGE
        occlusions = list(Path(occlusions_dir).glob('*'))
        if not occlusions:
            raise IOError('No occlusion masks found in %s' % occlusions_dir)
        occ_alpha = read_image_OpenCV(str(random.choice(occlusions)), opt, is_pair=False)
        occ_alpha[occ_alpha != 0] = 1          # binarise: 1 keeps a pixel, 0 hides it
        im[:, :, 3][im[:, :, 3] != 0] = 1      # binarise the current alpha as well
        new_alpha_artificial = occ_alpha * im[:, :, 3]
        # setting the new alpha channel
        im[:, :, 3] = new_alpha_artificial * 255
        fill_holes(new_alpha_artificial == 0)

    return im

def reduce_and_shuffle_dict_values_nested1level(d):
    """Flatten a two-level dict of lists into one shuffled list.

    ``d`` maps outer keys to inner dicts whose values are lists (or other
    iterables). All their items are concatenated, in outer-then-inner
    iteration order, and then shuffled with the global ``random`` RNG.

    Returns:
        A new list holding every item exactly once.
    """
    # A flat comprehension is linear. Repeated list `+` via reduce was
    # quadratic in the number of items.
    flat = [item
            for inner in d.values()
            for items in inner.values()
            for item in items]
    # One shuffle already yields a uniformly random permutation; repeating
    # it 100 times added cost but no randomness.
    random.shuffle(flat)
    return flat

def create_dataset_from_dir2subdir(dir, class_label=None, nitems=None):
    """Create a list of paths from a nested dir with two levels, sampling up to nitems per last-level entry.

    Expected layout: ``dir/<id>/<sub>/**/*.isomap.png``. The ``<id>``
    sub-directories are visited in sorted order. For each entry ``<sub>``
    inside one of them, up to ``nitems`` randomly chosen ``*.isomap.png``
    files found recursively under ``<sub>`` are kept.

    Args:
        dir: root directory (str or Path).
        class_label: value paired with every returned path.
        nitems: maximum number of files sampled per ``<id>/<sub>`` entry.
            ``None`` keeps all of them.

    Returns:
        A shuffled list of ``(path_str, class_label)`` tuples.
    """
    from collections import OrderedDict

    path = Path(dir)
    id_names = sorted(p.parts[-1] for p in path.glob('*') if os.path.isdir(p))

    # NOTE: this reseeds the *global* RNG from system entropy. The sampled
    # subset therefore differs between runs, and any seed the caller set is lost.
    random.seed()

    selected = OrderedDict()
    for id_name in id_names:
        selected[id_name] = OrderedDict()
        for sub in os.listdir(str(path / id_name)):
            found = list((path / id_name / sub).glob('**/*.isomap.png'))
            # min(len, None) is a TypeError in Python 3, so None is handled here as "all".
            k = len(found) if nitems is None else min(len(found), nitems)
            picks = random.sample(range(len(found)), k)
            selected[id_name][sub] = [str(found[ix]) for ix in picks]

    per_id = 'all' if nitems is None else '%d' % nitems
    print('Total found IDs in path %s: %d' % (path, len(selected)), '.. and selected %s per ID' % per_id)

    flat = reduce_and_shuffle_dict_values_nested1level(selected)
    return [(p, class_label) for p in flat]

def create_dataset_fromIDsubfolders(path, nitems=None):
    """Return unlabeled ``(path, None)`` samples from an ``<id>/<sub>`` folder tree.

    Args:
        path: root directory; must exist.
        nitems: maximum number of files sampled per ``<id>/<sub>`` entry.

    Returns:
        A shuffled list of ``(path_str, None)`` tuples.
    """
    assert os.path.isdir(path), '%s is not a valid directory' % path
    # Pass nitems by keyword: the second positional parameter of
    # create_dataset_from_dir2subdir is class_label, not nitems.
    return create_dataset_from_dir2subdir(path, nitems=nitems)

def create_dataset_fromIDsubfolders_withLabel(path, classes=('bad_fit', 'good_fit'), nitems=None):
    """Collect labeled samples from ``path/<class>/<id>/<sub>`` folders.

    Each sub-directory of ``path`` must be named after an entry of
    ``classes``, and its index in ``classes`` becomes the label. The combined
    list is shuffled with a fixed seed (1984).

    Args:
        path: root directory; must exist.
        classes: class folder names, in label order. Any sequence with
            ``.index`` works; the default is an immutable tuple.
        nitems: maximum number of files sampled per ``<id>/<sub>`` entry.

    Returns:
        A list of ``(path_str, class_label)`` tuples.

    Raises:
        ValueError: if a sub-directory's name is not in ``classes``.
    """
    assert os.path.isdir(path), '%s is not a valid directory' % path

    root = Path(path)
    class_paths = [p for p in root.glob('*') if os.path.isdir(p)]
    path_list = []
    for class_dir in class_paths:
        class_label = classes.index(class_dir.parts[-1])
        path_list += create_dataset_from_dir2subdir(class_dir, class_label, nitems)

    print('shuffling...', end='')
    random.seed(1984)
    # 100 passes are kept so the seeded ordering matches earlier runs.
    for _ in range(int(1e2)):
        random.shuffle(path_list)
    print('done')

    return path_list

def create_dataset_withLabel(path, classes=('bad_fit', 'good_fit'), return_pairs=None):
    """Collect labeled (image, mirror) samples from ``path/<class>/`` folders.

    Each sub-directory of ``path`` must be named after an entry of
    ``classes``, and its index in ``classes`` becomes the label. The
    per-class lists come from ``create_dataset``. The combined list is then
    shuffled with a fixed seed (1984).

    Args:
        path: root directory.
        classes: class folder names, in label order (any sequence with ``.index``).
        return_pairs: forwarded to ``create_dataset``. ``None`` (default)
            returns flat ``(path, label)`` tuples; any other value keeps each
            ``[(image, label), (mirror, label)]`` pair together. It used to be
            silently ignored.

    Returns:
        A list of samples as described above.

    Raises:
        ValueError: if a sub-directory's name is not in ``classes``.
    """
    root = Path(path)
    class_paths = [p for p in root.glob('*') if os.path.isdir(p)]
    pairs_path_list = []
    for class_dir in class_paths:
        class_label = classes.index(class_dir.parts[-1])
        pairs_path_list += create_dataset(class_dir, class_label, return_pairs)

    print('shuffling...', end='')
    random.seed(1984)
    # 100 passes are kept so the seeded ordering matches earlier runs.
    for _ in range(int(1e2)):
        random.shuffle(pairs_path_list)
    print('done')

    return pairs_path_list
        
def create_dataset(path, class_label=None, return_pairs=None):
    """Pair every ``<name>_m.png`` in ``path`` with its ``<name>.png`` counterpart.

    The list of pairs is shuffled with a fixed seed (1984).

    Args:
        path: directory holding both the images and their ``_m`` versions.
        class_label: label attached to every returned path.
        return_pairs: ``None`` (default) flattens the result into
            ``(path, label)`` tuples, with each original followed by its
            ``_m`` file. Any other value returns the list of
            ``[(original, label), (mirror, label)]`` pairs.

    Returns:
        A list as described above.

    Raises:
        IOError: if a ``_m.png`` file has no matching ``.png`` next to it.
    """
    root = Path(path)

    pairs_path_list = []
    for mirror in root.glob('*_m.png'):
        pair_name = mirror.parts[-1].split('_m.png')[0] + '.png'
        # Join directly instead of glob(pair_name): glob would treat [, *, ?
        # in file names as patterns and gave an opaque IndexError when the
        # counterpart was missing.
        original = root / pair_name
        if not original.exists():
            raise IOError('Missing counterpart %s for %s' % (original, mirror))
        pairs_path_list.append([(str(original), class_label), (str(mirror), class_label)])

    print('shuffling...', end='')
    random.seed(1984)
    # 100 passes are kept so the seeded ordering matches earlier runs.
    for _ in range(int(1e2)):
        random.shuffle(pairs_path_list)
    print('done')

    if return_pairs is None:
        pairs_path_list = [sample for pair in pairs_path_list for sample in pair]
    return pairs_path_list


def _normalize_first3(tensor):
    """Map channels 0-2 of a (C, H, W) tensor from [0, 1] to [-1, 1] in place; leave the rest untouched.

    This reproduces ``transforms.Normalize((0.5,)*3, (0.5,)*3)`` as it behaved
    in torchvision <= 0.2, which zipped the tensor with the 3 given means.
    Newer torchvision broadcasts instead and raises for tensors with more than
    3 channels, such as the 4-channel images built here.
    """
    n = min(tensor.size(0), 3)
    tensor[:n].sub_(0.5).div_(0.5)
    return tensor


def apply_data_transforms(im, which, opt, nchannels):
    """Convert an OpenCV image into a normalized tensor for one role in a sample.

    Args:
        im: image array (H, W, C). Presumably a 4-channel array from
            ``read_image_OpenCV`` (fill_gaps requires channel 3 as alpha).
        which: ``'target'`` fills gaps with the channel average, in place on
            ``im``. ``'targetXC'`` does no filling. ``'input'`` fills a copy
            with ``opt.fill``, adding artificial occlusions when ``opt.isTrain``.
        opt: options forwarded to ``fill_gaps``.
        nchannels: number of leading channels to keep.

    Returns:
        A float tensor (C, H, W). For uint8 input, its first 3 channels are in
        [-1, 1] and any further channel stays in [0, 1].

    Raises:
        ValueError: for an unknown ``which``. Previously the raw array was
            silently returned.
    """
    keep_channels = transforms.Lambda(lambda x: x[:, :, :nchannels])
    to_tensor = [transforms.ToTensor(), transforms.Lambda(_normalize_first3)]

    if which == 'target':
        # fill_gaps works in place, so `im` itself is modified here.
        steps = [transforms.Lambda(lambda x: fill_gaps(x, opt, fill_input_with='average')),
                 keep_channels] + to_tensor
    elif which == 'targetXC':
        steps = [keep_channels] + to_tensor
    elif which == 'input':
        # Work on a copy so the caller's image stays intact for the target transforms.
        steps = [transforms.Lambda(lambda x: x.copy()),
                 transforms.Lambda(lambda x: fill_gaps(x, opt, add_artificial=opt.isTrain)),
                 keep_channels] + to_tensor
    else:
        raise ValueError("unknown transform set %r (expected 'target', 'targetXC' or 'input')" % (which,))

    return transforms.Compose(steps)(im)


class BaseDataset(data.Dataset):
    """Minimal dataset interface; subclasses override ``initialize`` and ``name``."""

    def __init__(self):
        super().__init__()

    def name(self):
        """Return the display name of this dataset."""
        return 'BaseDataset'

    def initialize(self, opt):
        """Hook for subclasses to configure themselves from ``opt``; a no-op here."""
        return None


class AlignedDataset(BaseDataset):
    """Dataset of ``(image path, class label)`` samples rooted at ``opt.dataroot``.

    In training mode, ``opt.dataset_list[0]`` is read with
    ``create_dataset_withLabel`` and ``opt.dataset_list[1]`` with
    ``create_dataset_fromIDsubfolders_withLabel`` (2 files per leaf folder).
    Otherwise only ``opt.dataset_list[0]`` is read with ``create_dataset``,
    which attaches no label.
    """

    def initialize(self, opt):
        self.opt = opt
        self.root = opt.dataroot
        self.dataset_list = opt.dataset_list

        self.target_paths = []
        if self.opt.isTrain:
            self.target_paths += create_dataset_withLabel(os.path.join(self.root, self.dataset_list[0]))
            print(len(self.target_paths))
            self.target_paths += create_dataset_fromIDsubfolders_withLabel(os.path.join(self.root, self.dataset_list[1]), nitems=2)
            print(len(self.target_paths))
        else:
            self.target_paths += create_dataset(os.path.join(self.root, self.dataset_list[0]))

        self.dataset_size = len(self.target_paths)

    def __getitem__(self, index):
        """Build the sample dict for ``self.target_paths[index]``.

        Entries that are not produced in the current mode (``inst``, ``feat``,
        and ``target``/``target4C`` outside training) are the placeholder 0.
        """
        # target4C_tensor must be bound too: it is read below even when not training.
        target_tensor = inst_tensor = feat_tensor = target4C_tensor = 0

        input_nc = self.opt.label_nc if self.opt.label_nc != 0 else 3
        output_nc = self.opt.output_nc

        # read target image
        target_path = self.target_paths[index][0]
        target_label = self.target_paths[index][1]
        target_im = read_image_OpenCV(target_path, self.opt)

        # create input tensor first (it works on a copy of target_im)
        input_tensor = apply_data_transforms(target_im, 'input', self.opt, input_nc)

        # create output tensors; the 224x224 version is only needed for training
        if self.opt.isTrain:
            target_tensor = apply_data_transforms(target_im, 'target', self.opt, output_nc)
            target_im_resized = read_image_OpenCV(target_path, self.opt, target_size=(224, 224))
            target4C_tensor = apply_data_transforms(target_im_resized, 'targetXC', self.opt, nchannels=4)

        # create_dataset() (used in test mode) attaches label None, which
        # FloatTensor cannot hold. -1 marks an unlabeled sample.
        if target_label is None:
            target_label = -1
        target_label_tensor = torch.FloatTensor([target_label])

        input_dict = {'input': input_tensor, 'inst': inst_tensor,
                      'target': target_tensor, 'feat': feat_tensor,
                      'target4C': target4C_tensor,
                      'path': target_path, 'label': target_label_tensor}

        return input_dict

    def __len__(self):
        return len(self.target_paths)

    def name(self):
        return 'AlignedDataset'

def CreateDataset(opt):
    """Build an AlignedDataset configured from ``opt``, announce it and return it."""
    ds = AlignedDataset()
    ds.initialize(opt)
    print("dataset [%s] was created" % ds.name())
    return ds

class BaseDataLoader():
    """Minimal data-loader interface; subclasses override ``initialize`` and ``load_data``."""

    def __init__(self):
        pass

    def initialize(self, opt):
        """Store the options object on the instance."""
        self.opt = opt

    def load_data(self):
        """Return the iterable of batches; the base class has none.

        ``self`` was missing, so calling this on an instance raised TypeError.
        """
        return None

class CustomDatasetDataLoader(BaseDataLoader):
    """Wraps the dataset from ``CreateDataset`` in a ``torch.utils.data.DataLoader``."""

    def name(self):
        return 'CustomDatasetDataLoader'

    def initialize(self, opt):
        """Store ``opt``, build the dataset and the DataLoader over it."""
        super().initialize(opt)
        self.dataset = CreateDataset(opt)
        loader_kwargs = dict(
            batch_size=opt.batchSize,
            shuffle=not opt.serial_batches,   # serial_batches keeps file order
            num_workers=int(opt.nThreads),
        )
        self.dataloader = torch.utils.data.DataLoader(self.dataset, **loader_kwargs)

    def load_data(self):
        return self.dataloader

    def __len__(self):
        # NOTE(review): max_dataset_size only caps the length reported here;
        # the DataLoader itself still iterates over the whole dataset.
        size = len(self.dataset)
        cap = self.opt.max_dataset_size
        return cap if cap < size else size

def CreateDataLoader(opt):
    """Announce, initialize and return a CustomDatasetDataLoader for ``opt``."""
    loader = CustomDatasetDataLoader()
    print(loader.name())
    loader.initialize(opt)
    return loader


In [65]:
class options():
    """Hand-set option values for running the pipeline in this notebook without argparse.

    Mirrors a subset of TrainOptions. Some values differ from its defaults
    (e.g. loadSize=128 vs 1024, nThreads=16 vs 2).
    """
    def __init__(self):
        self.dataroot='/blanca/training_datasets/pix2pix/'
        self.dataset_list=['images_target_clean_classified', 'video_target_clean_classified']
        self.isTrain=True
        self.batchSize=1
        self.loadSize=128
        self.fineSize=128
        self.label_nc=35
        self.output_nc=3
        self.fill='W&B'
        self.serial_batches=False
        self.nThreads=16
        self.max_dataset_size=float("inf")
        # for displays: read by the training loop's save_fake check
        # (same default as TrainOptions --display_freq)
        self.display_freq=100
        # for training
        self.continue_train=False
        self.which_epoch='latest'
        self.phase='train'
        self.niter=50
        self.niter_decay=50
        self.beta1=0.5
        self.lr=0.0002

        # for the classifier
        self.class_nc=4
        self.n_layers_C=3
        self.num_C=1

opt = options()

In [None]:
# Build the training DataLoader from `opt` and report how many samples it has.
data_loader = CreateDataLoader(opt)
dataset, dataset_size = data_loader.load_data(), len(data_loader)
print('#training images = %d' % dataset_size)

In [66]:
import time

model = set_classifier_model()

# Resume bookkeeping: a fresh run starts at epoch 1 with no iterations done.
# The cell previously used these names without ever assigning them.
start_epoch, epoch_iter = 1, 0

total_steps = (start_epoch-1) * dataset_size + epoch_iter
for epoch in range(start_epoch, opt.niter + opt.niter_decay + 1):
    epoch_start_time = time.time()
    if epoch != start_epoch: epoch_iter = epoch_iter % dataset_size
    # The loop variable is `batch`, not `data`: rebinding `data` would shadow
    # the `torch.utils.data as data` module alias imported at the top.
    for _, batch in enumerate(dataset, start=epoch_iter):
        iter_start_time = time.time()
        total_steps += opt.batchSize
        epoch_iter += opt.batchSize

        # whether to collect output images
        save_fake = total_steps % opt.display_freq == 0

        ############## Forward Pass ######################
        # TODO: forward pass not wired up yet; drafted as:
        # losses, generated = model(Variable(batch['input']), Variable(batch['inst']),
        #                           Variable(batch['target']), Variable(batch['feat']),
        #                           Variable(batch['target4C']),
        #                           Variable(batch['label']),
        #                           infer=save_fake)

        # sum per device losses

RuntimeError: cuda runtime error (2) : out of memory at /opt/conda/conda-bld/pytorch_1512386481460/work/torch/lib/THC/generic/THCStorage.cu:58

In [16]:
import sys
sys.path.append('/blanca/utils')
import libraries
from utils_global import *

# Count the good-fit and bad-fit PNGs in the classified image set.
path_g = Path('/blanca/training_datasets/pix2pix/images_target_clean_classified/good_fit')
path_b = Path('/blanca/training_datasets/pix2pix/images_target_clean_classified/bad_fit')

path_glist = list(path_g.glob('*.png'))
print(len(path_glist))
path_blist = list(path_b.glob('*.png'))
print(len(path_blist))



1290
935


In [59]:
def set_classifier_model(checkpoint='/blanca/project/wip/pix2pixHDX_class-master/models/fitting_classifier/checkpoints_densenet_test/test_clean/model_best_c4_0.849776_0.022063_1519323.pth.tar'):
    """Build the DenseNetMulti fitting classifier, freeze it and load its weights.

    Reads the notebook-global ``opt``: ``opt.class_nc`` is passed as
    ``nchannels``, and ``opt`` itself is given to ``model.initialize``. When
    CUDA is available, the model is wrapped in ``DataParallel`` over every
    visible GPU and moved to the GPU.

    Args:
        checkpoint: path of a saved state dict. A falsy value skips loading.
            The default is the file previously built as ``'...1519323.pt' + 'h.tar'``.
            Other checkpoints from the same run (relative to
            ``.../models/``) are:
            fitting_classifier/checkpoints_densenet_test/test_clean/model_best_c4_0.742152_0.018293_1519316.pth.tar
            fitting_classifier/checkpoints_densenet_test/test_clean/model_best_c4_0.829596_0.023838_1519321.pth.tar

    Returns:
        The model with ``requires_grad=False`` on every parameter, in eval mode.
    """
    import networks
    use_gpu = torch.cuda.is_available()

    model = networks.DenseNetMulti(nchannels=opt.class_nc)
    model.initialize(opt)
    if use_gpu:
        model = torch.nn.DataParallel(model, device_ids=list(range(torch.cuda.device_count()))).cuda()

    for param in model.parameters():
        param.requires_grad = False

    if checkpoint:
        # Without a GPU, remap stored CUDA tensors to CPU so the file still loads.
        map_location = None if use_gpu else (lambda storage, loc: storage)
        state_dict = torch.load(checkpoint, map_location=map_location)
        if not use_gpu:
            # The GPU path loads this file into a DataParallel wrapper, so its
            # keys carry a 'module.' prefix that the unwrapped model lacks.
            prefix = 'module.'
            state_dict = {(k[len(prefix):] if k.startswith(prefix) else k): v
                          for k, v in state_dict.items()}
        model.load_state_dict(state_dict)

    # The classifier is frozen, so keep BatchNorm/Dropout in inference mode.
    model.eval()
    return model