In [1]:
from pathlib import Path
import correctingagent.world.world_generation as world_generation
from correctingagent.util.colour_dict import colour_dict
from correctingagent.util.colour_dict import fruit_dict
import random

In [2]:
dataset_name = 'fruit-4'
data_path = Path('/home/yucheng/Desktop/project/correcting-agent/data')
top_path = data_path / dataset_name

In [29]:
dataset_path = top_path / f'{dataset_name}{num_datasets}'
os.makedirs(dataset_path, exist_ok=True)

In [30]:
dataset_path

PosixPath('/home/yucheng/Desktop/project/correcting-agent/data/fruit-4/fruit-40')

In [3]:
import numpy as np
import json

In [7]:
test = [('purple', np.array([0.732085  , 0.92597844, 0.99624985])), ('orange', np.array([0.08524697, 0.90982403, 0.82954605]))]

In [8]:
fruit_dic = {}
for i, (colour, hsv) in enumerate(test):
    if colour == 'purple':
        new = np.array([1, 2, 3, 4])
        fruit_dic[f"b{i}"] = new

In [9]:
colour_object_dict = {f"b{i}": tuple(hsv) for i, (colour, hsv) in enumerate(test)}

In [10]:
colour_object_dict

{'b0': (0.732085, 0.92597844, 0.99624985),
 'b1': (0.08524697, 0.90982403, 0.82954605)}

In [11]:
fruit_dic

{'b0': array([1, 2, 3, 4])}

In [12]:
directory = Path("/home/yucheng/Desktop/project/correcting-agent/data/tmp")
json_name = directory / f"colours.json"
with open(json_name, 'w') as f:
    json.dump(colour_object_dict, f)

In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils import weight_norm

import os
import os.path as osp
import glob
import cv2
import glob

In [5]:
class VAE(nn.Module):
    def __init__(self, alpha=1, beta=1, gamma=1, latent_n=1, groups={}, device="cpu"):
        super(VAE, self).__init__()
        layers = []
        self.latent_n = latent_n
        self.groups = groups
        self.groups_n = len(groups.keys())
        self.device = device
        
        self.alpha = alpha
        self.beta = beta
        self.gamma = gamma

        # IMAGE ENCODER
        self.conv_channels = [32,32,64,64]
        self.dense_channels = [1024, 32]
        self.deconv_channels = [64, 64]

        kernel_size=7
        self.encoder_conv_0 = nn.Conv2d(3, self.conv_channels[0], kernel_size, padding=3, stride=2) # (32, 32)
        kernel_size=5
        self.encoder_conv_1 = nn.Conv2d(self.conv_channels[0], self.conv_channels[1], kernel_size, padding=2, stride=2) # (16, 16)
        kernel_size=3
        self.encoder_conv_2 = nn.Conv2d(self.conv_channels[1], self.conv_channels[2], kernel_size, padding=1, stride=2) # (8, 8)
        self.encoder_conv_3 = nn.Conv2d(self.conv_channels[2], self.conv_channels[3], kernel_size, padding=1, stride=2) # (4, 4)
        
        self.encoder_dense_0 = nn.Linear(self.dense_channels[0], self.dense_channels[1])
        self.encoder_mu = nn.Linear(self.dense_channels[1], self.latent_n)
        self.encoder_ln_var = nn.Linear(self.dense_channels[1], self.latent_n)

        # IMAGE DECONV DECODER
        self.decoder_dense_0 = nn.Linear(self.latent_n, self.dense_channels[1])
        self.decoder_dense_1 = nn.Linear(self.dense_channels[1], self.dense_channels[0])
        self.decoder_conv_3 = nn.Conv2d(self.conv_channels[3], self.conv_channels[2], kernel_size, padding=1)
        self.decoder_conv_2 = nn.Conv2d(self.conv_channels[2], self.conv_channels[1], kernel_size, padding=1)
        self.decoder_conv_1 = nn.Conv2d(self.conv_channels[1], self.conv_channels[1], kernel_size, padding=1)
        self.decoder_output_img = nn.Conv2d(self.conv_channels[1], 3, kernel_size, padding=1)

        # CLASSIFIERS
        self.classifiers = nn.ModuleList([nn.Linear(1, len(items)) for key, items in self.groups.items()])
        
        self.encoder = [self.encoder_conv_0,
                        self.encoder_conv_1,
                        self.encoder_conv_2,
                        self.encoder_conv_3,
                        self.encoder_dense_0,
                        self.encoder_mu,
                        self.encoder_ln_var]
        
        self.decoder = [self.decoder_dense_0,
                        self.decoder_dense_1,
                        self.decoder_conv_3,
                        self.decoder_conv_2,
                        self.decoder_conv_1,
                        self.decoder_output_img]
        
        self.init_weights()
    
    def init_weights(self):
        for i in range(len(self.encoder)):
            self.encoder[i].weight.data.normal_(0, 0.01)
            
        for i in range(len(self.decoder)):
            self.decoder[i].weight.data.normal_(0, 0.01)

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5*logvar)
        eps = torch.randn_like(std)
        return mu + eps*std
    
    def encode(self, x):
        
        conv_0_encoded = F.leaky_relu(self.encoder_conv_0(x))
        conv_1_encoded = F.leaky_relu(self.encoder_conv_1(conv_0_encoded))
        conv_2_encoded = F.leaky_relu(self.encoder_conv_2(conv_1_encoded))
        conv_3_encoded = F.leaky_relu(self.encoder_conv_3(conv_2_encoded))

        reshaped_encoded = torch.flatten(conv_3_encoded, start_dim=1)
        dense_0_encoded = F.leaky_relu(self.encoder_dense_0(reshaped_encoded))
        mu = self.encoder_mu(dense_0_encoded)
        logvar = self.encoder_ln_var(dense_0_encoded)
        
        z = self.reparameterize(mu, logvar)
        
        return z, mu, logvar
    
    def decode(self, z):
        
        dense_0_decoded = self.decoder_dense_0(z)
        dense_1_decoded = self.decoder_dense_1(dense_0_decoded)
        reshaped_decoded = dense_1_decoded.view((len(dense_1_decoded), self.conv_channels[-1], 4, 4))
        up_4_decoded = torch.nn.Upsample(scale_factor=2)(reshaped_decoded)
        deconv_3_decoded = F.relu(self.decoder_conv_3(up_4_decoded))
        up_3_decoded = torch.nn.Upsample(scale_factor=2)(deconv_3_decoded)
        deconv_2_decoded = F.relu(self.decoder_conv_2(up_3_decoded))
        up_2_decoded = torch.nn.Upsample(scale_factor=2)(deconv_2_decoded)
        deconv_1_decoded = F.relu(self.decoder_conv_1(up_2_decoded))
        up_1_decoded = torch.nn.Upsample(scale_factor=2)(deconv_1_decoded)
        out_img = self.decoder_output_img(up_1_decoded)
        
        return torch.sigmoid(out_img)
    
    def predict_labels(self, z, softmax=False):
        result = []
        
        for i in range(self.groups_n):
            prediction = self.classifiers[i](z[:, i, None])

            # need the check because the softmax_cross_entropy has a softmax in it
            if softmax:
                result.append(F.softmax(prediction))
            else:
                result.append(prediction)

        return result
    
    def get_latent(self, x):
        mu, logvar, _ = self.encode(x)
        z = self.reparameterize(mu, logvar)
        
        return z
    
    def forward(self, x):
        z, mu, logvar = self.encode(x)
        # img_out = self.sp_decode(z)
        img_out = self.decode(z)
        labels_out = self.predict_labels(z)
        
        return img_out, labels_out, mu, logvar 
    
    def get_loss(self):
        
        def loss(img_in, img_out, labels_in, labels_out, mu, logvar):
            
            rec = nn.MSELoss(reduction="none")(img_out, img_in)
            rec = torch.mean(torch.sum(rec.view(rec.shape[0], -1), dim=-1))

            label = 0
            for i in range(self.groups_n):
#                 print(i)
#                 print('label out')
#                 print(labels_out[i])
#                 print('----')
#                 print('label in')
#                 print(labels_in)
#                 print('----')
#                 label += nn.CrossEntropyLoss(ignore_index=100)(labels_out[i], labels_in[:, i])
                label += nn.CrossEntropyLoss(ignore_index=100)(labels_out[i], labels_in)
                
#                 print(label)
            kld = (((-0.5) * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())) / img_in.shape[0])

            rec *= self.alpha
            kld *= self.beta
            label *= self.gamma

            return rec + label + kld, rec, label, kld
    
        return loss

In [6]:
def fruit_loader(num=400):
    train_path = '/home/yucheng/Desktop/project/weak_label_lfd/fruit/Train/*'
    test_path = '/home/yucheng/Desktop/project/weak_label_lfd/fruit/test/*'
    crop_size = 64
    
    training_fruit_img = []
    training_label = []
    test_fruit_img = []
    test_label = []
#     total_fruit_img = []
#     total_label = []
    
    # load training images(#num per class)
    for dir_path in glob.glob(train_path):
        img_label = dir_path.split("/")[-1]
        count = 0
        for image_path in glob.glob(os.path.join(dir_path,"*.jpg")):
            image = cv2.imread(image_path,cv2.IMREAD_COLOR)
            image = cv2.resize(image, (crop_size, crop_size))
            image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
            training_fruit_img.append(image)
            training_label.append(img_label)
            count += 1
            if count == num:
                break
#             elif count < k+50:
#                 image = cv2.imread(image_path,cv2.IMREAD_COLOR)
#                 image = cv2.resize(image, (crop_size, crop_size))
#                 image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
#                 test_fruit_img.append(image)
#                 test_label.append(img_label)
#                 count = count + 1
#             else:
#                 break

    # load all test images
    for dir_path in glob.glob(test_path):
        img_label = dir_path.split("/")[-1]
        count = 0
        for image_path in glob.glob(os.path.join(dir_path,"*.jpg")):
            image = cv2.imread(image_path,cv2.IMREAD_COLOR)
            image = cv2.resize(image, (crop_size, crop_size))
            image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
            test_fruit_img.append(image)
            test_label.append(img_label)
            count += 1
            if count == 100:
                break 
            
    train_imgs = np.array(training_fruit_img)
    train_labels = np.array(training_label)
    test_imgs = np.array(test_fruit_img)
    test_labels = np.array(test_label)
    
    return train_imgs, train_labels, test_imgs, test_labels
#     return total_imgs, total_labels

In [7]:
def preprocessing(k, train_imgs, train_labels, test_imgs, test_labels,batchsize=32, readall=True, num=400):
    n_class = 10
    number_of_training = len(train_imgs) #total number
    number_of_test = len(test_imgs)
    train_n = k*n_class # how many training sample we need 
    
    if readall:
        train_indecies = np.random.choice(range(number_of_training), train_n, replace=False) # choose k*n_class number of training samples
        test_indecies = np.random.choice(range(number_of_test), 100, replace=False) # choose k number of test 
    else:
        train_indecies = []
        test_indecies = []
    
        for i in range(n_class):
            class_indecies = np.random.choice(range(num), k, replace=False) # random choose k number of training samples per class
            test_class_indecies = np.random.choice(range(100), 5, replace=False) # random pick 5 sample for each class
            train_indecies.extend(class_indecies + num*i)
            test_indecies.extend(test_class_indecies + 100*i)
    
    # print(train_indecies)

    label_to_id = {v:k for k,v in enumerate(np.unique(train_labels)) }
    id_to_label = {v:k for k,v in label_to_id.items() }

    label_id = np.array([label_to_id[i] for i in train_labels])
    one_hot_label = np.zeros((label_id.shape[0], n_class))
    one_hot_label[np.arange(label_id.shape[0]), label_id] = 1

    train_imgs = np.swapaxes(train_imgs, 1, 3) # (n, 3, 64, 64)
    train_imgs = train_imgs/255
    
    
#     label_to_id_t = {v:k for k,v in enumerate(np.unique(test_labels)) }
#     id_to_label_t = {v:k for k,v in label_to_id_t.items() }

    label_id_t = np.array([label_to_id[i] for i in test_labels])
    one_hot_label_t = np.zeros((label_id_t.shape[0], n_class))
    one_hot_label_t[np.arange(label_id_t.shape[0]), label_id_t] = 1

    test_imgs = np.swapaxes(test_imgs, 1, 3) # (n, 3, 64, 64)
    test_imgs = test_imgs/255
    
#     imgs = imgs.astype(np.float32)
#     one_hot_label = one_hot_label.astype(np.long)
    
    train_images = np.take(train_imgs, train_indecies, axis=0).astype(np.float32)
    train_labels = np.take(one_hot_label, train_indecies, axis=0).astype(np.long)
    label_id = np.take(label_id, train_indecies, axis=0).astype(np.long)
    # print(label_id)

    test_images = np.take(test_imgs, test_indecies, axis=0).astype(np.float32)
    test_labels = np.take(one_hot_label_t, test_indecies, axis=0).astype(np.long)
    label_id_t = np.take(label_id_t, test_indecies, axis=0).astype(np.long)
    #random shuffle training data
    state = np.random.get_state()
    np.random.shuffle(train_images)
    np.random.set_state(state)
    np.random.shuffle(train_labels)
    np.random.set_state(state)
    np.random.shuffle(label_id)
    
    train_dataset = torch.utils.data.TensorDataset(torch.tensor(train_images), torch.tensor(label_id))
    train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=batchsize, shuffle=False, num_workers=0)
        
    test_dataset = torch.utils.data.TensorDataset(torch.tensor(test_images), torch.tensor(label_id_t))
    test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=batchsize, shuffle=True, num_workers=0)
    
    
    return train_dataloader, train_images, train_labels, label_id, test_dataloader, test_images, test_labels, label_id_t

In [21]:
train_imgs_100, train_labels_100, test_imgs_100, test_labels_100 = fruit_loader()

In [9]:
test_imgs_100.shape

(1000, 64, 64, 3)

In [176]:
k = 400
if k == 5:
    batchsize = 1
else:
    batchsize = 32
num = 400

In [177]:
groups = {}
groups[0] = list(range(10))

PATH = "/home/yucheng/Desktop/project/correcting-agent/results_0128"
model_name = "model_k{}_0".format(k)
latent_n = 4
# net = VAE(latent_n=latent_n, groups=data_generator.groups)
net = VAE(latent_n=latent_n, groups=groups)
net.load_state_dict(torch.load(osp.join(PATH, model_name)))

<All keys matched successfully>

In [23]:
trainloader, train_images, train_labels, label_id, testloader, test_images, test_labels, label_id_t= preprocessing(k, train_imgs_100, train_labels_100, test_imgs_100, test_labels_100, batchsize=batchsize, readall=False, num=num)

In [13]:
imgs = test_images
labels = label_id_t

NameError: name 'test_images' is not defined

In [57]:
latent = {}
stats = {}
for group_idx in groups.keys():
    latent[group_idx] = {}
    stats[group_idx] = {}
    for label in range(len(groups[group_idx])):  
        indecies = [i for i, label_i in enumerate(labels) if label_i == label]
        print(indecies)
        filtered_data_imgs = np.take(imgs, indecies, axis=0).astype(np.float32)
    
        latent_out, _, _ = net.encode(torch.tensor(filtered_data_imgs))
        latent[group_idx][label] = latent_out.detach().numpy()

[35, 36, 37, 38, 39]
[30, 31, 32, 33, 34]
[40, 41, 42, 43, 44]
[15, 16, 17, 18, 19]
[5, 6, 7, 8, 9]
[45, 46, 47, 48, 49]
[25, 26, 27, 28, 29]
[20, 21, 22, 23, 24]
[0, 1, 2, 3, 4]
[10, 11, 12, 13, 14]


In [56]:
labels

array([8, 8, 8, 8, 8, 4, 4, 4, 4, 4, 9, 9, 9, 9, 9, 3, 3, 3, 3, 3, 7, 7,
       7, 7, 7, 6, 6, 6, 6, 6, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 2, 2, 2, 2,
       2, 5, 5, 5, 5, 5])

In [59]:
latent[0]

array([[ 2.3691063 ,  2.0686724 ,  0.11104929,  0.7617174 ],
       [ 2.5010726 ,  1.5249442 ,  1.1973306 , -0.98127985],
       [ 2.67693   ,  3.0622165 , -0.0385436 ,  0.2827496 ],
       [ 1.6710508 ,  0.8730894 , -0.37629336, -1.2154037 ],
       [ 2.8382    ,  2.8577678 ,  0.30587205,  0.27836552]],
      dtype=float32)

In [14]:
cc = {'apple': 2, 'banana': 2, 'blueberry': 0, 'corn': 2, 'eggplant': 3, 'kaki': 1, 'lemon': 0, 'mango': 0, 'orange': 0, 'pear': 0}

In [22]:
# imgs = test_imgs_100
imgs = np.swapaxes(test_imgs_100, 1, 3) # (n, 3, 64, 64)
imgs = imgs/255
labels = test_labels_100

In [23]:
label_to_id = {v.lower():k for k,v in enumerate(np.unique(test_labels_100)) }
id_to_label = {v:k for k,v in label_to_id.items() }

In [24]:
labels[0].lower()
id_to_label

{0: 'apple',
 1: 'banana',
 2: 'blueberry',
 3: 'corn',
 4: 'eggplant',
 5: 'kaki',
 6: 'lemon',
 7: 'mango',
 8: 'orange',
 9: 'pear'}

In [179]:
cc[labels[0].lower()]

select_10 = []
for (fruit, fnum) in cc.items():
    while not fnum == 0:
        select_10.append(label_to_id[fruit.lower()])
        fnum -= 1
select_10

[0, 0, 1, 1, 3, 3, 4, 4, 4, 5]

In [169]:
test_list = [0, 0, 1, 1, 3, 3, 4, 4, 4, 5]
incid = [np.random.randint(0,100) + 100*i for i in test_list]
incid

[55, 97, 120, 142, 310, 304, 496, 499, 481, 594]

In [170]:
filtered_data_imgs = np.take(imgs, incid, axis=0).astype(np.float32)
latent_out, _, _ = net.encode(torch.tensor(filtered_data_imgs))
latent_out

tensor([[ 5.0480,  0.8348, -0.1941, -0.0096],
        [ 4.7651,  1.1600,  0.5454,  0.0488],
        [-1.7374, -0.5336, -0.3445, -0.4323],
        [-1.4363, -1.3078,  1.0163, -0.2258],
        [-1.1887, -0.2771, -0.9389, -0.6221],
        [-0.1838,  0.0749,  0.4873,  0.0754],
        [ 1.5223,  0.8615,  0.4810, -1.2464],
        [ 2.1577,  1.6866, -0.4839, -0.9207],
        [ 2.4183,  2.7551,  0.1506,  0.4632],
        [ 0.4164,  1.1250, -0.6538,  1.6784]], grad_fn=<AddBackward0>)

In [183]:
fruit_dic = []
for i in range(10):
    fruit_name = id_to_label[select_10[i]]
    pred = latent_out[i].detach().numpy()
    fruit_dic.append((fruit_name, np.array(pred, dtype='float')))
fruit_dic

[('apple', array([ 5.0480237 ,  0.8347916 , -0.19412214, -0.00962634])),
 ('apple', array([4.76509047, 1.15995181, 0.54543334, 0.0487729 ])),
 ('banana', array([-1.73744595, -0.53363395, -0.34452271, -0.43234622])),
 ('banana', array([-1.43629956, -1.30781043,  1.01633394, -0.22584137])),
 ('corn', array([-1.18869901, -0.27705237, -0.93886608, -0.62213343])),
 ('corn', array([-0.1838163 ,  0.07486677,  0.48726872,  0.07544299])),
 ('eggplant', array([ 1.52232325,  0.86154848,  0.48096091, -1.24640119])),
 ('eggplant', array([ 2.1577189 ,  1.6866374 , -0.48394608, -0.92070025])),
 ('eggplant', array([2.41828609, 2.75510383, 0.15061529, 0.46316192])),
 ('kaki', array([ 0.41638625,  1.12501609, -0.65382934,  1.67839038]))]

#not this#

In [127]:
latent = {}
for group_idx in groups.keys():
    latent[group_idx] = {}
    for label in range(len(groups[group_idx])):  
        indecies = [i for i, label_i in enumerate(select_10) if label_i == label]
        print(indecies)
        if  not indecies == []:
            filtered_data_imgs = np.take(imgs, indecies, axis=0).astype(np.float32)
            latent_out, _, _ = net.encode(torch.tensor(filtered_data_imgs))
            latent[group_idx][label] = latent_out.detach().numpy()
latent

[0, 1]
[2, 3]
[]
[4, 5]
[6, 7, 8]
[9]
[]
[]
[]
[]


{0: {0: array([[ 4.7681403 ,  0.9798317 ,  1.3957351 ,  0.09522648],
         [ 4.7758036 ,  1.1752228 , -0.38502648,  0.12011665]],
        dtype=float32),
  1: array([[5.1187935 , 0.9028863 , 1.5332273 , 0.0834244 ],
         [5.265356  , 0.9243038 , 0.8544263 , 0.24710628]], dtype=float32),
  3: array([[ 4.897072  ,  0.88167876,  1.1739283 , -0.26181182],
         [ 4.884778  ,  0.875858  , -0.7620531 ,  0.01083781]],
        dtype=float32),
  4: array([[ 5.213238  ,  1.1264787 , -1.6092746 ,  0.22209597],
         [ 5.1575537 ,  0.9007012 ,  0.2085421 ,  0.41553402],
         [ 5.3002877 ,  0.974637  ,  0.8561054 ,  0.43721384]],
        dtype=float32),
  5: array([[ 4.9558244 ,  1.2688925 , -1.5719892 ,  0.13183293]],
        dtype=float32)}}

In [61]:
latent

{0: {0: array([[4.869334  , 0.94339395, 0.2182904 , 0.10302754],
         [5.049664  , 1.1624012 , 1.5175778 , 0.16589102]], dtype=float32),
  1: array([[ 5.212882  ,  0.94136167,  1.4061406 ,  0.05412238],
         [ 4.9311957 ,  0.91960835, -1.0083135 ,  0.23742832]],
        dtype=float32),
  3: array([[ 4.990582  ,  0.9827923 , -1.2049844 , -0.21127099],
         [ 4.8523054 ,  0.9534163 , -1.3470364 ,  0.07308765]],
        dtype=float32),
  4: array([[ 5.1115227 ,  0.98608345, -1.0387119 ,  0.12370645],
         [ 5.283029  ,  0.8838004 ,  0.6265168 ,  0.4067654 ],
         [ 5.596168  ,  0.89492565, -1.2474103 ,  0.40614352]],
        dtype=float32),
  5: array([[5.06676   , 1.1875188 , 0.02973331, 0.21877374]], dtype=float32)}}

In [47]:
fruit_dic = []
for fruit_id, pred_list in latent[0].items():
    for pred in pred_list:
        fruit_name = id_to_label[fruit_id]
        fruit_dic.append((fruit_name, pred))
fruit_dic

[('apple', array([4.7008495, 0.9251889, 0.9384835, 0.1172428], dtype=float32)),
 ('apple',
  array([ 4.9910173 ,  1.1615291 , -1.1792406 ,  0.16839597], dtype=float32)),
 ('banana',
  array([4.842707 , 0.9735563, 0.850629 , 0.0160583], dtype=float32)),
 ('banana',
  array([ 5.1717334 ,  0.9563578 , -0.47481042,  0.2953953 ], dtype=float32)),
 ('corn',
  array([ 4.773913  ,  0.8409624 , -1.4707614 , -0.28657475], dtype=float32)),
 ('corn',
  array([ 4.5907717 ,  0.796672  , -0.45461065,  0.00692413], dtype=float32)),
 ('eggplant',
  array([5.0864363 , 1.1268605 , 0.8535908 , 0.16164148], dtype=float32)),
 ('eggplant',
  array([ 5.29296   ,  0.9293656 , -0.07538732,  0.3935193 ], dtype=float32)),
 ('eggplant',
  array([5.3729153, 0.9852488, 0.3370942, 0.3613001], dtype=float32)),
 ('kaki',
  array([4.9832363 , 1.2316736 , 0.22599924, 0.08880075], dtype=float32))]

In [100]:
random.shuffle(fruit_dic)

In [103]:
list(label_to_id.keys())

['apple braeburn',
 'banana',
 'blueberry',
 'corn',
 'eggplant',
 'kaki',
 'lemon',
 'mango',
 'orange',
 'pear']

In [104]:
{'apple': ['appleb', 'apple'],
'banana' : ['banana'],
'blueberry' : ['blueberry'],
'corn' : ['corn'],
'eggplant': ['eggplant'],
'kaki': ['kaki'],
'lemon': ['lemon'],
'mango': ['mango'],
'orange' : ['orange'],
'pear': ['pear']}.keys()

dict_keys(['apple', 'banana', 'blueberry', 'corn', 'eggplant', 'kaki', 'lemon', 'mango', 'orange', 'pear'])

In [184]:
def generate_colour_generator(mean=0, std=0):
    """ Creates a colour generator function which generates random HSV colours based on the provided mean and std

    the SV channels will always use the same mean and std, ensuring that colours are not too dark or too light

    :param mean:
    :param std:
    :return: colour_generator function
    """
    def colour_generator():
        return np.array((((np.random.randn() * std + mean) % 360) / 3.6 , 100 - np.abs(np.random.randn() * 10), 100 - np.abs(np.random.rand() * 20)))/100
    return colour_generator

# These mean and std values seem to generate good values for each colour which all look sensibly like the specified colour
colour_values = {"red": (0, 5),
                 "orange": (30, 5),
                 "yellow": (58, 2),
                 "green": (120, 9),
                 "blue": (220, 13),
                 "purple": (270, 9),
                 "pink": (315, 9)}
# This dict maps a colour to its colour generator
# So to generate red use colour_generators['red']()
colour_generators = {colour: generate_colour_generator(*values) for colour, values in colour_values.items()}

In [202]:
import random

In [210]:
sss = random.sample(colour_values.keys(), 1)

In [212]:
colour_generators[sss[0]]()

array([0.90869884, 0.95102625, 0.91671468])

In [221]:
fruit_dic = []
for i in range(10):
    fruit_name = id_to_label[select_10[i]]
    pred = latent_out[i].detach().numpy()
    colour = random.sample(colour_values.keys(), 1)
    long_data = np.hstack((pred, colour_generators[colour[0]]()))
    fruit_dic.append((fruit_name, np.array(long_data, dtype='float')))

In [222]:
fruit_dic

[('apple',
  array([ 5.0480237 ,  0.8347916 , -0.19412214, -0.00962634,  0.57904829,
          0.85254352,  0.86180451])),
 ('apple',
  array([4.76509047, 1.15995181, 0.54543334, 0.0487729 , 0.00695483,
         0.888134  , 0.98691595])),
 ('banana',
  array([-1.73744595, -0.53363395, -0.34452271, -0.43234622,  0.63658747,
          0.99990141,  0.88476406])),
 ('banana',
  array([-1.43629956, -1.30781043,  1.01633394, -0.22584137,  0.1557267 ,
          0.94510291,  0.98111312])),
 ('corn',
  array([-1.18869901, -0.27705237, -0.93886608, -0.62213343,  0.30646622,
          0.79944076,  0.82605357])),
 ('corn',
  array([-0.1838163 ,  0.07486677,  0.48726872,  0.07544299,  0.00940821,
          0.99968653,  0.9864536 ])),
 ('eggplant',
  array([ 1.52232325,  0.86154848,  0.48096091, -1.24640119,  0.01762886,
          0.9500535 ,  0.95823115])),
 ('eggplant',
  array([ 2.1577189 ,  1.6866374 , -0.48394608, -0.92070025,  0.73739797,
          0.98862877,  0.96928054])),
 ('eggplant',
  a