In [1]:
def getImg(archive, mode, categories, dataset, data_path, 
           cat_test=None, occ_level='ZERO', occ_type=None, bool_load_occ_mask = False):

    if mode == 'train':
        train_imgs = []
        train_labels = []
        train_masks = []
        for category in categories:
            if dataset == 'pascal3d+':
                if occ_level == 'ZERO':
                    filelist = 'pascal3d+_occ/' + category + '_imagenet_train' + '.txt'
                    img_dir = 'pascal3d+_occ/TRAINING_DATA/' + category + '_imagenet'
            elif dataset == 'coco':
                if occ_level == 'ZERO':
                    img_dir = 'coco_occ/{}_zero'.format(category)
                    filelist = 'coco_occ/{}_{}_train.txt'.format(category, occ_level)

            with archive.open(filelist, 'r') as fh:
                contents = fh.readlines()
            img_list = [cc.strip().decode('ascii') for cc in contents]
            label = categories.index(category)
            for img_path in img_list:
                if dataset=='coco':
                    if occ_level == 'ZERO':
                        img = img_dir + '/' + img_path + '.jpg'
                    else:
                        img = img_dir + '/' + img_path + '.JPEG'
                else:
                    img = img_dir + '/' + img_path + '.JPEG'
                occ_img1 = []
                occ_img2 = []
                train_imgs.append(img)
                train_labels.append(label)
                train_masks.append([occ_img1,occ_img2])
        
        return train_imgs, train_labels, train_masks

    else:
        test_imgs = []
        test_labels = []
        occ_imgs = []
        for category in cat_test:
            if dataset == 'pascal3d+':
                filelist = data_path + 'pascal3d+_occ/' + category + '_imagenet_occ.txt'
                img_dir = data_path + 'pascal3d+_occ/' + category + 'LEVEL' + occ_level
                if bool_load_occ_mask:
                    if  occ_type=='':
                        occ_mask_dir = 'pascal3d+_occ/' + category + 'LEVEL' + occ_level+'_mask_object'
                    else:
                        occ_mask_dir = 'pascal3d+_occ/' + category + 'LEVEL' + occ_level+'_mask'
                    occ_mask_dir_obj = 'pascal3d+_occ/0_old_masks/'+category+'_imagenet_occludee_mask/'
            elif dataset == 'coco':
                if occ_level == 'ZERO':
                    img_dir = 'coco_occ/{}_zero'.format(category)
                    filelist = 'coco_occ/{}_{}_test.txt'.format(category, occ_level)
                else:
                    img_dir = 'coco_occ/{}_occ'.format(category)
                    filelist = 'coco_occ/{}_{}.txt'.format(category, occ_level)

#             if os.path.exists(filelist):
            with archive.open(filelist, 'r') as fh:
                contents = fh.readlines()
            img_list = [cc.strip().decode('ascii') for cc in contents]
            label = categories.index(category)
            for img_path in img_list:
                if dataset != 'coco':
                    if occ_level=='ZERO':
                        img = img_dir + occ_type + '/' + img_path[:-2] + '.JPEG'
                        occ_img1 = []
                        occ_img2 = []
                    else:
                        img = img_dir + occ_type + '/' + img_path + '.JPEG'
                        if bool_load_occ_mask:
                            occ_img1 = occ_mask_dir + '/' + img_path + '.JPEG'
                            occ_img2 = occ_mask_dir_obj + '/' + img_path + '.png'
                        else:
                            occ_img1 = []
                            occ_img2 = []

                else:
                    img = img_dir + occ_type + '/' + img_path + '.jpg'
                    occ_img1 = []
                    occ_img2 = []

                test_imgs.append(img)
                test_labels.append(label)
                occ_imgs.append([occ_img1,occ_img2])
#             else:
#                 print('FILELIST NOT FOUND: {}'.format(filelist))
        return test_imgs, test_labels, occ_imgs


def imgLoader(archive, img_path,mask_path,bool_resize_images=True,bool_square_images=False):
    """Load one image, and optionally its occlusion mask, as tensors.

    Args:
        archive: Object exposing a ``zipfile.ZipFile``-style ``open(name)``.
        img_path: Name of the image inside ``archive``.
        mask_path: Two-element sequence. A truthy first item is the name of
            a mask image inside ``archive``; the second item is passed to
            ``cv2.imread``.
        bool_resize_images: When true the image is resized before conversion.
        bool_square_images: With resizing enabled, force a 224x224 image.
            Otherwise the shorter side is scaled to 224 and the aspect ratio
            is kept.

    Returns:
        ``(img, mask)``: ``img`` is the output of ``ToTensor`` followed by
        ``Normalize`` with the mean/std constants below; ``mask`` is a tensor
        created with ``torch.from_numpy``.
    """
    # Decode the pixels while the archive member is still open so that the
    # handle can be released immediately afterwards.
    with archive.open(img_path) as img_fh:
        input_image = Image.open(img_fh)
        input_image.load()

    if bool_resize_images:
        if bool_square_images:
            # resize() returns a new image; the result has to be kept.
            input_image = input_image.resize((224, 224), Image.LANCZOS)
        else:
            sz = input_image.size
            min_size = np.min(sz)
            if min_size != 224:
                new_size = tuple(int(v) for v in np.asarray(sz) * (224 / min_size))
                input_image = input_image.resize(new_size, Image.LANCZOS)
    preprocess = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    img = preprocess(input_image)

    if mask_path[0]:
        with archive.open(mask_path[0]) as mask_fh:
            mask1 = np.array(Image.open(mask_fh))
        mask1 = myresize(mask1, 224, 'short')
        try:
            # NOTE(review): the second mask is read from the file system via
            # cv2.imread, not from `archive` - confirm this is intended.
            mask2 = cv2.imread(mask_path[1])[:, :, 0]
            mask2 = mask2[:mask1.shape[0], :mask1.shape[1]]
            mask = ((mask1 == 255) * (mask2 == 255)).astype(np.float64)
        except Exception:
            # Best effort: if the second mask is missing or cannot be
            # combined, fall back to the first mask alone.
            mask = mask1
    else:
        # NOTE(review): `img` comes out of ToTensor, so these two dimensions
        # are presumably (channels, height) rather than (height, width) -
        # confirm before relying on the placeholder's shape.
        mask = np.ones((img.shape[0], img.shape[1])) * 255.0

    mask = torch.from_numpy(mask)
    return img,mask


class Imgset():
    """Index-addressable collection of (image, mask, label) samples.

    Holds parallel sequences of image references, mask references and labels
    and defers the actual loading to the ``loader`` callable, which is invoked
    as ``loader(archive, image, mask, bool_resize_images=True,
    bool_square_images=...)`` and must return an ``(img, mask)`` pair.
    """

    def __init__(self, archive, imgs, masks, labels, loader,bool_square_images=False):
        self.archive, self.loader = archive, loader
        self.images, self.masks, self.labels = imgs, masks, labels
        self.bool_square_images = bool_square_images

    def __getitem__(self, index):
        """Load sample ``index`` and return ``(img, mask, label)``."""
        image_ref = self.images[index]
        sample_label = self.labels[index]
        mask_ref = self.masks[index]
        loaded_img, loaded_mask = self.loader(
            self.archive,
            image_ref,
            mask_ref,
            bool_resize_images=True,
            bool_square_images=self.bool_square_images,
        )
        return loaded_img, loaded_mask, sample_label

    def __len__(self):
        """Number of samples, taken from the image sequence."""
        return len(self.images)

def save_checkpoint(state, filename, is_best):
    """Persist ``state`` to ``filename`` via ``torch.save`` when ``is_best`` is truthy.

    When ``is_best`` is falsy nothing is written and only a notice is printed.
    """
    if not is_best:
        print("=> Validation Accuracy did not improve")
        return
    print("=> Saving new checkpoint")
    torch.save(state, filename)

In [2]:
import glob
import pickle
import os
import zipfile
from PIL import Image
import pdb

import torch
from torch.utils.data import DataLoader
from torchvision import transforms
import cv2

from CompositionalNets.Code.vMFMM import *
from CompositionalNets.Initialization_Code.config_initialization import dataset, categories, vc_num, data_path, cat_test, device_ids, Astride, Apad, Arf, vMF_kappa, layer,init_path, nn_type, dict_dir, offset, extractor
from CompositionalNets.Code.helpers import myresize

In [53]:
# Learn the vMF dictionary: sample feature vectors from the extractor's output
# on the training images, cluster them with vMFMM, pickle the cluster centers
# and assignments, then write example image patches and a 4x4 summary sheet for
# every cluster.

# Number of images to train on per category
# Clustering ignored after this threshold is met
img_per_cat = 1000

# Number of feature vectors to sample from each image's feature map
samp_size_per_img = 20

imgs_par_cat = np.zeros(len(categories))
bool_load_existing_cluster = False
bins = 4

archive = zipfile.ZipFile(os.path.join(data_path, 'CompNet_data.zip'))
occ_level = 'ZERO'
occ_type = ''
imgs, labels, masks = getImg(archive, 'train', categories, dataset, data_path, cat_test, occ_level, occ_type, bool_load_occ_mask=False)
imgset = Imgset(archive, imgs, masks, labels, imgLoader, bool_square_images=False)
data_loader = DataLoader(dataset=imgset, batch_size=1, shuffle=False)
nimgs = len(imgs)

loc_set = []
feat_set = []
nfeats = 0
for ii, data in enumerate(data_loader):
    input_img, mask, label = data
    if np.mod(ii, 500) == 0:
        print('{} / {}'.format(ii, len(imgs)))

    fname = imgs[ii]
    # The loader runs unshuffled with batch_size=1, so sample ii of the loader
    # is entry ii of the lists returned by getImg.
    category = labels[ii]

    if imgs_par_cat[category] < img_per_cat:
        with torch.no_grad():
            tmp = extractor(input_img.cuda(device_ids[0]))[0].detach().cpu().numpy()
        height, width = tmp.shape[1:3]

        # Crop image by some offset
        tmp = tmp[:, offset:height - offset, offset:width - offset]

        # Flatten image at each channel
        gtmp = tmp.reshape(tmp.shape[0], -1)
        if gtmp.shape[1] >= samp_size_per_img:
            rand_idx = np.random.permutation(gtmp.shape[1])[:samp_size_per_img]
        else:
            # NOTE(review): with fewer positions than samp_size_per_img this
            # keeps only (samp_size_per_img - n) of the n positions; it looks
            # like the remainder of a padding scheme whose append step was
            # disabled. Behavior kept as is - confirm the intent.
            rand_idx = np.random.permutation(gtmp.shape[1])[:samp_size_per_img - gtmp.shape[1]]
        tmp_feats = gtmp[:, rand_idx].T

        for cnt, rr in enumerate(rand_idx):
            ihi, iwi = np.unravel_index(rr, (height - 2 * offset, width - 2 * offset))
            # Scale feature-map coordinates by the input/feature size ratio.
            hi = (ihi + offset) * (input_img.shape[2] / height) - Apad
            wi = (iwi + offset) * (input_img.shape[3] / width) - Apad
            loc_set.append([category, ii, hi, wi, hi + Arf, wi + Arf])
            feat_set.append(tmp_feats[cnt, :])

        imgs_par_cat[category] += 1


feat_set = np.asarray(feat_set)
loc_set = np.asarray(loc_set).T

print(feat_set.shape)
model = vMFMM(vc_num, 'k++')
model.fit(feat_set, vMF_kappa, max_it=150)
with open(dict_dir+'dictionary_{}_{}.pickle'.format(layer,vc_num), 'wb') as fh:
    pickle.dump(model.mu, fh)


# For every cluster keep the `num` samples with the largest model.p value.
num = 50
SORTED_IDX = []
SORTED_LOC = []
for vc_i in range(vc_num):
    sort_idx = np.argsort(-model.p[:, vc_i])[0:num]
    SORTED_IDX.append(sort_idx)
    # Iterate over sort_idx itself: it is shorter than `num` when fewer
    # feature vectors were collected.
    SORTED_LOC.append([loc_set[:, sample_idx] for sample_idx in sort_idx])

with open(dict_dir + 'dictionary_{}_{}_p.pickle'.format(layer,vc_num), 'wb') as fh:
    pickle.dump(model.p, fh)
p = model.p

print('save top {0} images for each cluster'.format(num))
example = [None for vc_i in range(vc_num)]
out_dir = os.path.join(dict_dir, f'cluster_images_{layer}_{vc_num}')
os.makedirs(out_dir, exist_ok=True)

for vc_i in range(vc_num):
    sort_idx = SORTED_IDX[vc_i]
    opath = os.path.join(out_dir, str(vc_i))
    os.makedirs(opath, exist_ok=True)
    locs = []
    for idx, sample_idx in enumerate(sort_idx):
        iloc = loc_set[:, sample_idx]
        category = iloc[0]
        loc = iloc[1:6].astype(int)
        # Write at most one patch per source image and cluster.
        if loc[0] not in locs:
            locs.append(loc[0])

            # Close the archive member as soon as the pixels are decoded.
            with archive.open(imgs[int(loc[0])]) as archive_img_path:
                img = np.array(Image.open(archive_img_path))
            img = myresize(img, 224, 'short')
            patch = img[loc[1]:loc[3], loc[2]:loc[4], :]
            if patch.size:
                # PIL decodes colour images as RGB while cv2.imwrite expects
                # BGR, so flip the channel axis before writing.
                if patch.shape[2] == 3:
                    patch = np.ascontiguousarray(patch[:, :, ::-1])
                out_path = os.path.join(opath, f'{str(idx)}.JPEG')
                cv2.imwrite(out_path, patch)
    if vc_i % 10 == 0:
        print(vc_i)

# print summary for each vc
for c in range(vc_num):
    iidir = os.path.join(out_dir, str(c))
    # The patches live inside the per-cluster directory, so the pattern needs
    # the path separator; sorting keeps the sheet layout reproducible.
    files = sorted(glob.glob(os.path.join(iidir, '*.JPEG')))
    width = 100
    height = 100
    canvas = np.zeros((0, 4 * width, 3))
    for jj in range(4):
        row = np.zeros((height, 0, 3))
        for ii in range(4):
            img = None
            if (jj * 4 + ii) < len(files):
                # cv2.imread returns None when the file cannot be read.
                img = cv2.imread(files[jj * 4 + ii])
            if img is None:
                # Missing or unreadable patch: leave the tile black.
                img = np.zeros((height, width, 3))
            else:
                img = cv2.resize(img, (width, height))
            row = np.concatenate((row, img), axis=1)
        canvas = np.concatenate((canvas, row), axis=0)
    cv2.imwrite(os.path.join(out_dir, f'{str(c)}.JPEG'), canvas)

0
10
20
30
40
50
60
70
80
90
100
110
120
130
140
150
160
170
180
190
200
210
220
230
240
250
260
270
280
290
300
310
320
330
340
350
360
370
380
390
400
410
420
430
440
450
460
470
480
490
500
510


In [5]:
# Read back the pickled dictionary written for this layer / cluster count.
dictfile = os.path.join(dict_dir, f'dictionary_{layer}_{vc_num}.pickle')
with open(dictfile, 'rb') as fh:
    centers = pickle.load(fh)

# Flags defined for the mixture-model step.
bool_pytroch, bool_plot_view_p3d = True, False

# Output directory for the mixture models; created on first use.
mixdir = init_path + 'mix_model_vmf_{}_EM_all/'.format(dataset)
if os.path.exists(mixdir) is False:
    os.makedirs(mixdir)

# Occlusion setting and spectral split threshold.
occ_level, occ_type = 'ZERO', ''
spectral_split_thresh = 0.1

In [9]:
from CompositionalNets.Initialization_Code.Learn_mix_model_vMF_view import learn_mix_model_vMF

loading /project/6052161/mattlk/workplace/CompNet-Medical/CompositionalNets/models/init_vgg/dictionary_vgg/dictionary_pool5_512.pickle


In [None]:
# Bare reference to the imported function: this expression does not call it.
learn_mix_model_vMF