In [1]:
import copy
import numpy as np
from sklearn.utils import shuffle
from torchvision import datasets, transforms
from torch.utils.data import ConcatDataset, Dataset
import torch
import random
from sklearn.preprocessing import StandardScaler

In [2]:
def get_selected_classes(target_classes):
    """Draw ``target_classes`` distinct class ids out of the 100 available ones.

    The draw uses NumPy's global random state, so seed ``np.random``
    beforehand if a reproducible selection is needed.

    Args:
        target_classes: number of class ids to draw (at most 100, because
            the draw is without replacement).

    Returns:
        1-D NumPy array of unique class ids from ``0..99`` in random order.
    """
    candidate_ids = list(range(100))
    return np.random.choice(candidate_ids, target_classes, replace=False)

def V2_get_continual_ember_class_data(data_dir, train=True):
    """Load the feature matrix and label vector of one data split.

    Reads ``<data_dir>/XY_train.npz`` (keys ``X_train`` / ``Y_train``) when
    ``train`` is truthy, otherwise ``<data_dir>/XY_test.npz`` (keys
    ``X_test`` / ``Y_test``).

    Args:
        data_dir: directory that contains the ``.npz`` archives.
        train: selects the training split when truthy, the test split otherwise.

    Returns:
        Tuple ``(X, Y)`` of NumPy arrays read from the archive.
    """
    split = 'train' if train else 'test'
    archive_path = data_dir + '/' + f'XY_{split}.npz'

    # np.load on an .npz returns a lazily-read NpzFile that keeps the file
    # open; the context manager closes it once both arrays are materialised.
    with np.load(archive_path) as archive:
        return archive[f'X_{split}'], archive[f'Y_{split}']


def get_ember_selected_class_data(data_dir, selected_classes, train=True):
    """Load one split and keep only the samples of the selected classes.

    Labels are re-mapped: every sample whose original label equals
    ``selected_classes[i]`` receives the new label ``i``. The result is
    shuffled (features and labels together) before being returned.

    Args:
        data_dir: directory passed on to ``V2_get_continual_ember_class_data``.
        selected_classes: iterable of original class ids to keep; its order
            defines the new labels.
        train: load the training split when truthy, the test split otherwise.

    Returns:
        Tuple ``(X_, Y_)`` with ``X_`` as float32 and ``Y_`` as int64.
    """
    all_X, all_Y = V2_get_continual_ember_class_data(data_dir, train=bool(train))

    X_, Y_ = [], []
    for new_label, original_label in enumerate(selected_classes):
        class_rows = all_X[np.where(all_Y == original_label)]
        X_.extend(class_rows)
        # Every row of this class gets the position of the class as its label.
        Y_.extend([new_label] * len(class_rows))

    X_ = np.float32(np.array(X_))
    Y_ = np.array(Y_, dtype=np.int64)
    X_, Y_ = shuffle(X_, Y_)

    if train:
        print(f' Training data X {X_.shape} Y {Y_.shape}')
    else:
        print(f' Test data X {X_.shape} Y {Y_.shape}')

    return X_, Y_


class malwareSubDataset(Dataset):
    '''Sub-sample a dataset, keeping only the samples whose label is in [sub_labels].

    Each returned sample is the original feature vector right-padded with
    float32 zeros from ``orig_length_features`` up to
    ``target_length_features``. Labels are returned unchanged.

    Args:
        original_dataset: pair ``(samples, labels)`` of indexable containers.
        orig_length_features: length of each original feature vector.
        target_length_features: length of each returned feature vector;
            must be >= ``orig_length_features``.
        sub_labels: container of labels to keep (tested with ``in``).

    Raises:
        ValueError: if ``target_length_features < orig_length_features``.
    '''

    def __init__(self, original_dataset, orig_length_features,\
                 target_length_features, sub_labels):
        super().__init__()
        self.dataset, self.origlabels = original_dataset
        self.orig_length_features = orig_length_features
        self.target_length_features = target_length_features

        pad_width = target_length_features - orig_length_features
        if pad_width < 0:
            # Fail at construction instead of at the first __getitem__ call.
            raise ValueError(
                f"target_length_features ({target_length_features}) must not be "
                f"smaller than orig_length_features ({orig_length_features})")
        # The padding is identical for every sample, so build it only once.
        self.padded_features = np.zeros(pad_width, dtype=np.float32)

        # Positions in the original dataset whose label was requested.
        self.sub_indeces = [index for index in range(len(self.dataset))
                            if self.origlabels[index] in sub_labels]

    def __len__(self):
        return len(self.sub_indeces)

    def __getitem__(self, index):
        source_index = self.sub_indeces[index]
        sample = np.concatenate((self.dataset[source_index], self.padded_features))
        target = self.origlabels[source_index]

        return (sample, target)


def get_malware_multitask_experiment(target_classes,\
                                     orig_feats_length, target_feats_length,\
                                     scenario, tasks,\
                                     data_dir="../../../ember2018/top_class_bases/top_classes_100"):
    """Load, re-label and standardise the data for the continual-learning run.

    Randomly selects ``target_classes`` classes, loads their train and test
    samples, fits a ``StandardScaler`` on the training features and applies
    it to both splits. The selected classes and the task/label layout are
    printed for inspection.

    Args:
        target_classes: number of classes to select.
        orig_feats_length, target_feats_length, scenario, tasks: accepted
            but not used by this function.
        data_dir: directory holding ``XY_train.npz`` and ``XY_test.npz``.

    Returns:
        Tuple ``(ember_train, ember_test, classes_per_task)`` where each
        dataset is an ``(X, Y)`` pair and ``classes_per_task`` is 5.
    """
    classes_per_task = 5

    selected_classes = get_selected_classes(target_classes)
    print(selected_classes)

    X_train, Y_train = get_ember_selected_class_data(data_dir, selected_classes, train=True)
    X_test, Y_test = get_ember_selected_class_data(data_dir, selected_classes, train=False)

    # Scaling statistics come from the training split only.
    scaler = StandardScaler().fit(X_train)
    X_train = scaler.transform(X_train)
    X_test = scaler.transform(X_test)

    # Layout is hard-coded: one task with labels 0..49, then ten tasks of
    # five labels each (50..99). It is only printed, not returned.
    later_tasks = [list(np.array(range(classes_per_task)) + classes_per_task * task_id)
                   for task_id in range(10, 20)]
    labels_per_task = [list(range(50))] + later_tasks
    print(labels_per_task)

    return ((X_train, Y_train), (X_test, Y_test), classes_per_task)


# Experiment configuration for this notebook.
target_classes = 100
scenario = 'class'
# Original and padded feature lengths are equal here, i.e. no padding.
orig_feats_length, target_feats_length = 2381, 2381
tasks = 11

# (train_datasets, test_datasets), classes_per_task, (ember_test, labels_per_task) = get_malware_multitask_experiment(
#         target_classes=target_classes, scenario=scenario,\
#         orig_feats_length=orig_feats_length,\
#         target_feats_length=target_feats_length, tasks=tasks
#     )


# Load and standardise the data; ember_train / ember_test are (X, Y) pairs
# that the cells below read from the kernel namespace.
ember_train, ember_test, classes_per_task  = get_malware_multitask_experiment(
        target_classes=target_classes, scenario=scenario,\
        orig_feats_length=orig_feats_length,\
        target_feats_length=target_feats_length, tasks=tasks
    )





[46 23 74 19 61 71 36 35 92  5 87 79 64 57 80  9 84 77 49 63 10 13  1 26
 31 67 72 51 32 65 15 70 11 60  0 62 82 58 66 45 73 25 52 59 20  3 99 30
 48 97 37 24 16 94 76 83 75 81 53 17 55 14 38 40 33 44 54 27 56 41 12 21
 68 85 39  7 88 43 78  8 34 28 69 29 47 89 18 98 91 96  2 95 93 86 50 22
 90 42  6  4]
 Training data X (303331, 2381) Y (303331,)
 Test data X (33704, 2381) Y (33704,)
[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49], [50, 51, 52, 53, 54], [55, 56, 57, 58, 59], [60, 61, 62, 63, 64], [65, 66, 67, 68, 69], [70, 71, 72, 73, 74], [75, 76, 77, 78, 79], [80, 81, 82, 83, 84], [85, 86, 87, 88, 89], [90, 91, 92, 93, 94], [95, 96, 97, 98, 99]]


In [None]:
def get_partial_data(X, Y, replay_portion):
    """Return a random fraction of the paired arrays ``X`` and ``Y``.

    A random permutation of the sample positions is drawn with the
    ``random`` module, and the first ``int(len(Y) * replay_portion)``
    positions are used to index both arrays.

    Args:
        X: array of samples, indexable by a list of positions.
        Y: array of labels, same length as ``X``.
        replay_portion: fraction of samples to keep.

    Returns:
        Tuple ``(X_subset, Y_subset)`` selected with the same positions.
    """
    order = list(range(len(Y)))
    random.shuffle(order)
    keep = order[:int(len(order) * replay_portion)]
    return X[keep], Y[keep]




def get_class_grs(PreviousTasksData, PreviousTasksLabels,\
                  replay_portion=0.5):
    """Build a randomly sub-sampled replay set from earlier tasks.

    All samples whose label appears in any of the given tasks are gathered
    (task by task, label by label), and ``get_partial_data`` then keeps a
    random ``replay_portion`` of them.

    Args:
        PreviousTasksData: pair ``(X, Y)`` of NumPy arrays.
        PreviousTasksLabels: iterable of label lists, one per earlier task.
        replay_portion: fraction of the gathered samples to keep.

    Returns:
        Tuple ``(all_replay_X, all_replay_Y)`` of NumPy arrays.
    """
    X, Y = PreviousTasksData

    pooled_X, pooled_Y = [], []
    for label_group in PreviousTasksLabels:
        for wanted in label_group:
            hits = np.where(Y == wanted)
            pooled_X.extend(X[hits])
            pooled_Y.extend(Y[hits])

    return get_partial_data(np.array(pooled_X), np.array(pooled_Y), replay_portion)

def get_current_task_data(CurrentTaskData, CurrentTaskLabels):
    """Select the samples that belong to the labels of one task.

    Args:
        CurrentTaskData: pair ``(X, Y)`` of NumPy arrays.
        CurrentTaskLabels: iterable of labels that make up the task.

    Returns:
        Tuple ``(X_task_samples, Y_task_labels)`` of NumPy arrays, grouped
        in the order of ``CurrentTaskLabels``.
    """
    X, Y = CurrentTaskData

    picked_X, picked_Y = [], []
    for wanted in CurrentTaskLabels:
        hits = np.where(Y == wanted)
        picked_X.extend(X[hits])
        picked_Y.extend(Y[hits])

    return np.array(picked_X), np.array(picked_Y)



class malwareTrainDataset(Dataset):
    """Minimal ``Dataset`` wrapper around a ``(samples, labels)`` pair.

    Item ``i`` is the tuple ``(samples[i], labels[i])``; the length is the
    number of labels.
    """

    def __init__(self, dataset):
        super().__init__()
        samples, labels = dataset
        self.samples = samples
        self.labels = labels

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, index):
        return (self.samples[index], self.labels[index])
    

    
# Not read by the code in this cell.
diversity_mode = 'grs' #'ifs' #ahrs 


# Task layout: first task holds labels 0..49, followed by ten tasks of
# `classes_per_task` labels each (task_id 10..19 gives 50..99 when
# classes_per_task is 5).
first_task = list(range(50)) #[0, 1, .., 49]

labels_per_task = list([first_task]) + [ 
    list(np.array(range(classes_per_task)) + classes_per_task * task_id) for task_id in range(10,20)]

#print(labels_per_task)
    
# ember_train, ember_test    
    
# all_current_X, all_current_Y = get_current_task_data(ember_train, labels_per_task[4])

# Replay set built from the third task only (slice [2:3]).
all_replay_X, all_replay_Y = get_class_grs(ember_train, labels_per_task[2:3])


# all_X = np.concatenate((all_current_X, all_replay_X))
# all_Y = np.concatenate((all_current_Y, all_replay_Y))

# print(np.unique(all_Y))

# task_train_dataset = malwareTrainDataset((all_X, all_Y))


# cnt = 0
# labels_ = []
# for a, b in task_train_dataset:
    
#     cnt += 1
    
#     #print(a)
#     labels_.append(b)
# #     if cnt == 10:
# #         break
# print(np.unique(labels_))

In [None]:
# Display the distinct labels present in the replay set.
np.unique(all_replay_Y)

In [None]:
def get_partial_data(X, Y, replay_portion):
    """Return a random fraction of the paired arrays ``X`` and ``Y``.

    A random permutation of the sample positions is drawn with the
    ``random`` module, and the first ``int(len(Y) * replay_portion)``
    positions are used to index both arrays.

    Args:
        X: array of samples, indexable by a list of positions.
        Y: array of labels, same length as ``X``.
        replay_portion: fraction of samples to keep.

    Returns:
        Tuple ``(X_subset, Y_subset)`` selected with the same positions.
    """
    order = list(range(len(Y)))
    random.shuffle(order)
    keep = order[:int(len(order) * replay_portion)]
    return X[keep], Y[keep]

def get_class_grs(PreviousTasksData, PreviousTasksLabels,\
                  replay_portion=0.5):
    """Build a randomly sub-sampled replay set from earlier tasks.

    All samples whose label appears in any of the given tasks are gathered,
    ``get_partial_data`` keeps a random ``replay_portion`` of them, and the
    shapes of the result are printed.

    Args:
        PreviousTasksData: pair ``(X, Y)`` of NumPy arrays.
        PreviousTasksLabels: iterable of label lists, one per earlier task.
        replay_portion: fraction of the gathered samples to keep.

    Returns:
        Tuple ``(all_replay_X, all_replay_Y)`` of NumPy arrays.
    """
    X, Y = PreviousTasksData

    pooled_X, pooled_Y = [], []
    for label_group in PreviousTasksLabels:
        for wanted in label_group:
            hits = np.where(Y == wanted)
            pooled_X.extend(X[hits])
            pooled_Y.extend(Y[hits])

    all_replay_X, all_replay_Y = get_partial_data(
        np.array(pooled_X), np.array(pooled_Y), replay_portion)
    print(f'all_replay_X {all_replay_X.shape} all_replay_Y {all_replay_Y.shape}')

    return all_replay_X, all_replay_Y



# Replay set drawn from the first three tasks (slice [:3]).
all_replay_X, all_replay_Y = get_class_grs(ember_train, labels_per_task[:3])


In [None]:
from torch.utils.data import DataLoader
from torch.utils.data.dataloader import default_collate

def get_data_loader(dataset, batch_size, cuda=False, collate_fn=None, drop_last=False, augment=False):
    '''Return a class-balanced <DataLoader>-object for the provided <DataSet>-object [dataset].

    Every sample is weighted by ``1 / (number of samples with its label)``
    and a ``WeightedRandomSampler`` draws ``len(dataset)`` samples per epoch
    with replacement.

    Args:
        dataset: dataset yielding ``(sample, label)`` pairs; it is deep-copied
            and iterated once in full to collect the labels.
        batch_size: batch size of the returned loader.
        cuda: when truthy, ``num_workers=0`` and ``pin_memory=True`` are
            passed to the loader.
        collate_fn: optional collate function; defaults to ``default_collate``.
        drop_last: forwarded to ``DataLoader``.
        augment: accepted but not used by this function.

    Returns:
        A ``torch.utils.data.DataLoader`` over the copied dataset.

    Raises:
        ValueError: if the dataset is empty.
    '''
    dataset_ = copy.deepcopy(dataset)

    y = np.array([label for _, label in dataset_], dtype=int)
    if len(y) == 0:
        raise ValueError("cannot build a weighted sampler for an empty dataset")

    # One pass gives each sample's class position and the per-class counts,
    # replacing a nested Python loop over samples x classes.
    _, class_position, class_sample_count = np.unique(
        y, return_inverse=True, return_counts=True)
    samples_weight = (1. / class_sample_count)[class_position]

    samples_weight = torch.from_numpy(samples_weight).float()
    sampler = torch.utils.data.WeightedRandomSampler(samples_weight, len(samples_weight), replacement=True)

    # Create and return the <DataLoader>-object
    # (shuffle stays False: DataLoader rejects shuffle=True together with a sampler).
    return DataLoader(
        dataset_, batch_size=batch_size, shuffle=False,
        collate_fn=(collate_fn or default_collate), drop_last=drop_last, sampler=sampler,
        **({'num_workers': 0, 'pin_memory': True} if cuda else {})
    )

# Build a balanced loader over the first task's training data and keep an
# iterator on it (consumed by a later `next(data_loader)` cell).
all_current_X, all_current_Y = get_current_task_data(ember_train, labels_per_task[0])
task_train_dataset = malwareTrainDataset((all_current_X, all_current_Y))
data_loader = iter(get_data_loader(task_train_dataset, 128, drop_last=True))

In [None]:


# Rebuild the task layout (labels 0..49, then ten groups of
# `classes_per_task` labels) and print, per task, its labels and the labels
# of the task immediately before it.
first_task = list(range(50)) #[0, 1, .., 49]

labels_per_task = list([first_task]) + [ 
    list(np.array(range(classes_per_task)) + classes_per_task * task_id) for task_id in range(10,20)]

# `task` is 1-based here, so labels_per_task[task-1] is the current task.
for task, per_task in enumerate(labels_per_task, 1):
    if task != 1:
        # NOTE(review): the printed count is `task-2`, but the slice
        # [task-2:task-1] always contains exactly one task; confirm which is intended.
        print(f'replay {task-2} previous tasks, {labels_per_task[task-2:task-1]}')
        print(f'current task {task} {labels_per_task[task-1]}')
        print()
    else:
        print(task, labels_per_task[task-1])
        print()

In [None]:
# Pull one batch from the loader iterator to inspect it.
next(data_loader)

In [None]:



# Rebuild the task layout and extract the training data of the fifth task
# (index 4).
first_task = list(range(50)) #[0, 1, .., 49]

labels_per_task = list([first_task]) + [ 
    list(np.array(range(classes_per_task)) + classes_per_task * task_id) for task_id in range(10,20)]

#print(labels_per_task)
    
# ember_train, ember_test    
    
all_current_X, all_current_Y = get_current_task_data(ember_train, labels_per_task[4])



In [None]:
# Show the shapes and compare the distinct labels found against the
# requested labels of task index 4 (`==` on a NumPy array compares element-wise).
all_current_X.shape, all_current_Y.shape, np.unique(all_current_Y)==labels_per_task[4]

In [None]:
def get_partial_data(X, Y, replay_portion):
    """Return a random fraction of the paired arrays ``X`` and ``Y``.

    A random permutation of the sample positions is drawn with the
    ``random`` module, and the first ``int(len(Y) * replay_portion)``
    positions are used to index both arrays.

    Args:
        X: array of samples, indexable by a list of positions.
        Y: array of labels, same length as ``X``.
        replay_portion: fraction of samples to keep.

    Returns:
        Tuple ``(X_subset, Y_subset)`` selected with the same positions.
    """
    order = list(range(len(Y)))
    random.shuffle(order)
    keep = order[:int(len(order) * replay_portion)]
    return X[keep], Y[keep]


# Neither of these two settings is read by the loop below.
num_samples_per_class = 200

diversity_mode = 'random' #'ifs' #ahrs 
#family_based == None




# NOTE(review): `train_datasets` is not assigned anywhere above this cell in
# the visible notebook (only in commented-out code and in a later cell), so
# this loop relies on leftover kernel state.
# For each task, materialise its (x, y) pairs into arrays and compute the
# distinct labels; nothing is done with them afterwards in the active code.
for task, train_data in enumerate(train_datasets, 1):
    print(f'task {task}')
    cnt = 0
    #print(task, train_data)
    Ys = []
    Xs = []
    for x, y in train_data:
        #print(type(xy))
        Ys.append(y)
        Xs.append(x)
        
    #print(len(np.unique(Ys)), len(Xs))
    
    Xs, Ys = np.array(Xs), np.array(Ys)
    unique_labels = np.unique(Ys)
    #print(type(Ys), type(Xs))
#     if diversity_mode == 'random' and not family_based:
#         for dY in unique_labels:
#             dY_ind = np.where(Ys == dY)
#             #print(len(dYs[0]))
#             #print(Ys)
#             #print(type(dY_ind), dY_ind)
#             dYs = Ys[dY_ind]
#             dXs = Xs[dY_ind]
#             #print(len(dYs) == len(dXs))
            
            
            
#         pass
    
    #print(unique_labels)
    
    
#         if cnt == 10:
#             break
#         cnt += 1

In [None]:
from sklearn.ensemble import IsolationForest
import random

def get_IFBased_samples(family_name, family_data,\
                        contamination,\
                        num_samples_per_malware_family):
    """Pick replay samples for one family with an Isolation Forest.

    If the family has at most ``num_samples_per_malware_family`` samples,
    all of them are returned. Otherwise an ``IsolationForest`` splits the
    samples into anomalous (prediction -1) and similar (prediction 1):

    * when there are at least as many anomalous samples as the budget, half
      of the budget is drawn at random from the anomalous samples and the
      other half from the similar ones;
    * otherwise all anomalous samples are kept and the rest of the budget is
      drawn at random from the similar ones.

    If fewer similar samples exist than are needed, all of them are used.

    Args:
        family_name: identifier of the family; not used in the computation.
        family_data: array-like of feature vectors for the family.
        contamination: forwarded to ``IsolationForest``.
        num_samples_per_malware_family: sample budget for the family.

    Returns:
        NumPy array of the selected samples (anomalous first, then similar).
    """
    data_X = np.array(family_data)
    budget = num_samples_per_malware_family

    if len(data_X) > budget:

        # fit the model
        clf = IsolationForest(max_samples=len(data_X), contamination=contamination)
        clf.fit(data_X)
        y_pred = clf.predict(data_X)

        anomalous_idx = np.flatnonzero(y_pred == -1)
        similar_idx = np.flatnonzero(y_pred == 1)
        assert len(anomalous_idx) + len(similar_idx) == len(y_pred)

        if len(anomalous_idx) >= budget:
            # Too many outliers: spend half the budget on them, half on inliers.
            remaining_samples_to_pick = int(budget / 2)
            chosen = random.sample(range(len(anomalous_idx)), remaining_samples_to_pick)
            anomalous_idx = anomalous_idx[chosen]
        else:
            remaining_samples_to_pick = budget - len(anomalous_idx)
        anomalous_samples = data_X[anomalous_idx]

        # Compare against the number of similar samples, not the length of
        # the np.where tuple, and only sample when the pool is larger than
        # what is still needed.
        if remaining_samples_to_pick < len(similar_idx):
            print(f'similar_samples_pool {len(similar_idx)} remaining_samples_to_pick {remaining_samples_to_pick}')
            chosen = random.sample(range(len(similar_idx)), remaining_samples_to_pick)
            similar_idx = similar_idx[chosen]
        similar_samples = data_X[similar_idx]

        print(f'anomalous_samples {len(anomalous_samples)} similar_samples {len(similar_samples)}')
        replay_samples = np.concatenate((anomalous_samples, similar_samples))
    else:
        replay_samples = data_X

    print(f'Num replay samples {len(replay_samples)}')
    return replay_samples



def _collect_replay_pool(X, Y, task_label_groups):
    """Gather every sample of (X, Y) whose label appears in any given task."""
    pool_X, pool_Y = [], []
    for task_labels in task_label_groups:
        for task_Y in task_labels:
            Y_task_ind = np.where(Y == task_Y)
            pool_X.extend(X[Y_task_ind])
            pool_Y.extend(Y[Y_task_ind])
    return np.array(pool_X), np.array(pool_Y)


def get_class_grs(PreviousTasksData, PreviousTasksLabels, replay_config = 'ifs',
                  replay_portion=0.5, contamination=0.1, num_samples_per_class=200):
    """Build the replay set for earlier tasks according to ``replay_config``.

    Modes:
        ``'frs'``: every sample of the earlier tasks is returned
            (presumably "full replay"; this branch used to return ``None``).
        ``'ifs'``: per class, samples are chosen by ``get_IFBased_samples``.
        anything else: a random ``replay_portion`` of all earlier-task
            samples, via ``get_partial_data`` (all of them if the portion
            is 1.0).

    Note that the third positional parameter here is ``replay_config``,
    whereas the earlier definitions of this function in the notebook take
    ``replay_portion`` in that position.

    Args:
        PreviousTasksData: pair ``(X, Y)`` of NumPy arrays.
        PreviousTasksLabels: iterable of label lists, one per earlier task.
        replay_config: selection mode, see above.
        replay_portion: fraction kept in the random mode.
        contamination: forwarded to ``get_IFBased_samples`` in ``'ifs'`` mode.
        num_samples_per_class: per-class budget in ``'ifs'`` mode.

    Returns:
        Tuple ``(all_replay_X, all_replay_Y)`` of NumPy arrays.
    """
    X, Y = PreviousTasksData

    if replay_config == 'ifs':
        all_replay_X, all_replay_Y = [], []
        for CurrentTaskLabels in PreviousTasksLabels:
            for task_Y in CurrentTaskLabels:
                task_samples = X[np.where(Y == task_Y)]
                task_samples = get_IFBased_samples(task_Y, task_samples,
                                                   contamination,
                                                   num_samples_per_class)
                all_replay_X.extend(task_samples)
                # Every selected sample of this class carries the class label.
                all_replay_Y.extend([task_Y] * len(task_samples))
        return np.array(all_replay_X), np.array(all_replay_Y)

    all_replay_X, all_replay_Y = _collect_replay_pool(X, Y, PreviousTasksLabels)

    if replay_config == 'frs' or replay_portion == 1.0:
        return all_replay_X, all_replay_Y

    return get_partial_data(all_replay_X, all_replay_Y, replay_portion)
        

        
def get_current_task_test_data(X, Y, CurrentTaskLabels):
    """Select the test samples that belong to the labels of one task.

    Args:
        X: NumPy array of samples.
        Y: NumPy array of labels, aligned with ``X``.
        CurrentTaskLabels: iterable of labels that make up the task.

    Returns:
        Tuple ``(X_task_samples, Y_task_labels)`` of NumPy arrays, grouped
        in the order of ``CurrentTaskLabels``.
    """
    X_task_samples = []
    Y_task_labels = []

    for task_Y in CurrentTaskLabels:
        Y_task_ind = np.where(Y == task_Y)
        X_task_samples.extend(X[Y_task_ind])
        Y_task_labels.extend(Y[Y_task_ind])

    X_task_samples, Y_task_labels = np.array(X_task_samples),\
                                    np.array(Y_task_labels)

    # The original returned the misspelled, undefined name `Y_task_label`.
    return X_task_samples, Y_task_labels
        
# Task-layout configuration followed by a short dry run of the replay loop.
init_classes = 50
target_classes = 100
num_class = target_classes

scenario = 'class'

if scenario == 'class':
    # First task gets `init_classes` labels; the remaining labels are split
    # evenly over the other tasks.
    initial_task_num_classes = init_classes
    if initial_task_num_classes > target_classes:
        raise ValueError(f"Initial Number of Classes cannot be more than {target_classes} classes!")
    # Total number of tasks is hard-coded as 11 here.
    left_tasks = 11 - 1 
    classes_per_task_except_first_task = int((num_class - initial_task_num_classes) / left_tasks)

    #print(selected_classes)
    first_task = list(range(initial_task_num_classes))

    labels_per_task = [first_task] + [list(initial_task_num_classes +\
                                       np.array(range(classes_per_task_except_first_task)) +\
                                       classes_per_task_except_first_task * task_id)\
                                      for task_id in range(left_tasks)]
    classes_per_task = classes_per_task_except_first_task

else:
    # Equal-sized tasks; `tasks` comes from an earlier cell.
    classes_per_task = int(np.floor(num_class / tasks))

    labels_per_task = [list(np.array(range(classes_per_task)) +\
                        classes_per_task * task_id) for task_id in range(tasks)]    

# Only the first two tasks are kept for this dry run.
labels_per_task = labels_per_task[:2]
# Loop over all tasks.
for task, per_task in enumerate(labels_per_task, 1):

    if task != 1:
            # Replay data from all earlier tasks plus the current task's data
            # (get_class_grs is called with its default replay_config).
            all_replay_X, all_replay_Y = get_class_grs(ember_train, labels_per_task[:task-1])
            all_current_X, all_current_Y = get_current_task_data(ember_train, labels_per_task[task-1])
            print(f'all_current_X {all_current_X.shape} all_replay_X {all_replay_X.shape}')
            all_X = np.concatenate((all_current_X, all_replay_X))
            all_Y = np.concatenate((all_current_Y, all_replay_Y))
    else:
        # First task: nothing to replay yet.
        all_X, all_Y = get_current_task_data(ember_train, labels_per_task[task-1])
        
                
        

In [None]:
# Load the raw training labels straight from the archive for the
# class-size statistics in the following cells.
data_dir="../../../ember2018/top_class_bases/top_classes_100/"
XY_train = np.load(data_dir + 'XY_train.npz')
Y_tr = XY_train['Y_train']



In [None]:
# Count the number of training samples for every distinct label;
# numS[i] is the count for Ys[i].
Ys = np.unique(Y_tr)

numS = []

for Yi in Ys:
    YiN = len(np.where(Y_tr == Yi)[0])
    numS.append(YiN)


In [None]:
#numS

In [None]:
# Number of classes that have at least 1000 training samples.
len(np.where(np.array(numS) >= 1000)[0])

In [None]:
def get_current_task_rest_data(X, Y, CurrentTaskLabels):
    """Select from (X, Y) the samples whose label belongs to one task.

    Args:
        X: NumPy array of samples.
        Y: NumPy array of labels, aligned with ``X``.
        CurrentTaskLabels: iterable of labels that make up the task.

    Returns:
        Tuple ``(X_task_samples, Y_task_labels)`` of NumPy arrays, grouped
        in the order of ``CurrentTaskLabels``.
    """
    picked_X, picked_Y = [], []
    for wanted in CurrentTaskLabels:
        hits = np.where(Y == wanted)
        picked_X.extend(X[hits])
        picked_Y.extend(Y[hits])

    return np.array(picked_X), np.array(picked_Y)

def get_rest_task_data(RestTasksData, RestTasksLabels):
    """Gather the samples of several tasks into one pair of arrays.

    Args:
        RestTasksData: pair ``(X, Y)`` of NumPy arrays.
        RestTasksLabels: iterable of label lists, one per task.

    Returns:
        Tuple ``(all_rest_X, all_rest_Y)`` of NumPy arrays, ordered task by
        task and, within a task, label by label.
    """
    X, Y = RestTasksData

    gathered_X, gathered_Y = [], []
    for label_group in RestTasksLabels:
        for wanted in label_group:
            hits = np.where(Y == wanted)
            gathered_X.extend(X[hits])
            gathered_Y.extend(Y[hits])

    return np.array(gathered_X), np.array(gathered_Y)


# Dry run of the "offline" replay bookkeeping in the task scenario:
# 20 equal-sized tasks, iterating over the first five.
# `num_class` comes from an earlier cell.
tasks = 20
classes_per_task = int(np.floor(num_class / tasks))

labels_per_task = [list(np.array(range(classes_per_task)) +\
                    classes_per_task * task_id) for task_id in range(tasks)] 

#print(labels_per_task)

replay_mode ="offline"
scenario = "task"
#task = 1

# Replay samples accumulated over the tasks seen so far.
all_replay_X, all_replay_Y = [], []

# `task` is 1-based: labels_per_task[task-1] is the current task.
for task in range(1,6):
    if replay_mode=="offline" and scenario == "task":
        print(f'current task {task}')
        
        if task == 1:
            current_X, current_Y = get_current_task_data(ember_train, labels_per_task[task-1])

            print(f'current task labels {labels_per_task[task-1]}')
            print(f'rest task labels {labels_per_task[task:]}')
            # For the first task the "rest" set is the data of all later tasks.
            rest_X, rest_Y = get_rest_task_data(ember_train, labels_per_task[task:])
            all_rest_X = rest_X
            all_rest_Y = rest_Y

        else:
            current_X, current_Y = get_current_task_data(ember_train, labels_per_task[task-1])

            # Replay samples of the immediately preceding task only.
            prev_replay_X, prev_replay_Y = get_class_grs(ember_train, labels_per_task[task-2:task-1])
            print(f'prev_replay_Y {prev_replay_Y.shape} {np.unique(prev_replay_Y)}')
            
            # Append to the running replay set (or start it at task 2).
            if task > 2:
                all_replay_X, all_replay_Y = np.concatenate((all_replay_X, prev_replay_X)),\
                                                     np.concatenate((all_replay_Y, prev_replay_Y))
            else:
                all_replay_X, all_replay_Y = prev_replay_X, prev_replay_Y
                
            print(f'all_replay_Y {np.unique(all_replay_Y)}')
            
            
            # "Rest" = accumulated replay of earlier tasks + full data of later tasks.
            if task != tasks:
                rest_X, rest_Y = get_rest_task_data(ember_train, labels_per_task[task:])

                all_rest_X = np.concatenate((all_replay_X, rest_X))
                all_rest_Y = np.concatenate((all_replay_Y, rest_Y))

            else:
                all_rest_X = all_replay_X
                all_rest_Y = all_replay_Y

        print()
        print(f'\n current_Y {len(current_Y)}')
        print(f'\n all_rest_Y {len(all_rest_Y)}\n')
    #     standard_scaler = standardization.partial_fit(current_X)

    #     current_X = standard_scaler.transform(current_X)
    #     all_rest_X = standard_scaler.transform(all_rest_X)

        x_test, y_test = ember_test
    #     x_test = standard_scaler.transform(x_test)


        # One dataset slot per task.
        train_datasets = [None]*tasks
        for ct, labels in enumerate(labels_per_task):
            # NOTE(review): `ct` is 0-based while `task` is 1-based, so the
            # current task's data (labels_per_task[task-1]) is stored in
            # slot `task`, not `task-1`; confirm this offset is intended.
            if ct == task:
                train_dataset = malwareTrainDataset((current_X, current_Y))
                training_dataset = train_dataset
                train_datasets[ct] = train_dataset
            else:
                rest_task_X, rest_task_Y = get_current_task_rest_data(all_rest_X, all_rest_Y, labels)
                train_datasets[ct] = malwareTrainDataset((rest_task_X, rest_task_Y))


    # Still inside the `for task` loop: build (and discard) a loader iterator
    # over the slot indexed by the 1-based task number.
    previous_datasets = train_datasets
    iter(get_data_loader(previous_datasets[task], 512, drop_last=True))

In [None]:
# Build a loader iterator for the slot indexed by the last value of `task`.
iter(get_data_loader(previous_datasets[task], 512, drop_last=True))

In [None]:
# Distinct labels in the "rest" set and in the most recent replay draw.
np.unique(all_rest_Y), np.unique(prev_replay_Y)

In [None]:
# Build a balanced loader for dataset slot 1.
get_data_loader(previous_datasets[1], 512, drop_last=True)

In [None]:
# Number of samples in dataset slot 1.
len(previous_datasets[1])

In [None]:
# Per-task replay loader refresh for the task scenario.
# NOTE(review): `iters_left_previous`, `data_loader_previous`, `utils` and
# `cuda` are not defined anywhere in the visible notebook, so this cell
# cannot run from a fresh kernel as written; `get_data_loader` is defined
# locally above rather than in a `utils` module.
batch_size = 512

if scenario=="task":

    up_to_task = task
    
    # Split the batch evenly over the tasks seen so far (rounded up).
    batch_size_replay = int(np.ceil(batch_size/up_to_task)) if (up_to_task>1) else batch_size
    
    # -in Task-IL scenario, need separate replay for each task
    for task_id in range(up_to_task):
        # Never ask for a batch larger than the task's dataset.
        batch_size_to_use = min(batch_size_replay, len(previous_datasets[task_id]))
        
        iters_left_previous[task_id] -= 1
        
        # When the counter hits zero, create a fresh loader iterator and
        # reset the counter to its number of batches.
        if iters_left_previous[task_id] == 0:
            data_loader_previous[task_id] = iter(utils.get_data_loader(
                previous_datasets[task_id], batch_size_to_use, cuda=cuda, drop_last=True
            ))
            iters_left_previous[task_id] = len(data_loader_previous[task_id])