In [3]:
import sys
import numpy as np
import torch
from pprint import pprint

In [134]:
tensor = torch.tensor([i//50 for i in range(100)])

buffer_size = 100
task_id = 1 # python's off by one thing
cpt = 2

task_partitions = tensor.chunk(task_id)
class_partitions = [partition for task_partition in task_partitions for partition in task_partition.chunk(cpt)] # a list of tensors partioned along new class lines

buffer=[]
# First we delete
for i, class_partition in enumerate(class_partitions):
    # number of samples per class at current task partition
    samples_per_class = buffer_size // (task_id+1) // cpt
    
    # number of samples to be deleted from each class
    x = np.abs(samples_per_class - class_partition.shape[0])
    class_partition = class_partition[:-x]
    class_partitions[i] = class_partition
    
# Second we add in new samples
for i in range(2,4):
    # Add in new class samples to buffer (add in ipm samples, ipm samples for new task calculated external of this overwrite function)
    new_samples = i*torch.ones(samples_per_class)
    class_partitions.append(new_samples)
tensor = torch.cat(class_partitions)

### Initial Attempt

In [41]:
# Make maximal use of Buffer at all time
# Not perfect as sometimes buffer will empty itself if numbers dont work out

def overwrite(buffer, buffer_size, task_id, cpt):
    
    task_partitions = buffer.chunk(task_id)
    class_partitions = [partition for task_partition in task_partitions for partition in task_partition.chunk(cpt)] # a list of tensors partioned along new class lines
    
    # First we delete
    for i, class_partition in enumerate(class_partitions):
        # number of samples per class at current task partition
        #samples_per_class = int(np.ceil(buffer_size / (task_id+1) / cpt))
        if i % 2 == 0:
            samples_per_class = int(np.floor(buffer_size / (task_id+1) / cpt))
        else:
            samples_per_class = int(np.ceil(buffer_size / (task_id+1) / cpt))
        #print(samples_per_class)

        # number of samples to be deleted from each class
        x = np.abs(samples_per_class - class_partition.shape[0])
        #print(x)
        class_partition = class_partition[:-x]
        class_partitions[i] = class_partition
        
    # Second we add in new samples
    for i in range(task_id*cpt, task_id*cpt + cpt):
        # Add in new class samples to buffer (add in ipm samples, ipm samples for new task calculated external of this overwrite function)
        new_samples = i*torch.ones(samples_per_class)
        class_partitions.append(new_samples)
        
    buffer = torch.cat(class_partitions)
    return buffer

In [57]:
buffer_size = 100
n_tasks = 5
cpt = 2

for t in range(n_tasks):
    if t == 0:
        buffer = torch.tensor([i//50 for i in range(buffer_size)])
        print('Examples in buffer: {}'.format(torch.unique(buffer, return_counts = True)[1].sum()))
        print(torch.unique(buffer, return_counts = True))
    else:
        buffer = overwrite(buffer, buffer_size, task_id=t, cpt=cpt)
        print('Examples in buffer: {}'.format(torch.unique(buffer, return_counts = True)[1].sum()))
        print(torch.unique(buffer, return_counts = True))

Examples in buffer: 100
(tensor([0, 1]), tensor([50, 50]))
Examples in buffer: 100
(tensor([0., 1., 2., 3.]), tensor([25, 25, 25, 25]))
Examples in buffer: 100
(tensor([0., 1., 2., 3., 4., 5.]), tensor([16, 17, 16, 17, 17, 17]))
Examples in buffer: 101
(tensor([0., 1., 2., 3., 4., 5., 6., 7.]), tensor([12, 13, 12, 13, 12, 13, 13, 13]))
Examples in buffer: 100
(tensor([0., 1., 2., 3., 4., 5., 6., 7., 8., 9.]), tensor([10, 10, 10, 10, 10, 10, 10, 10, 10, 10]))


### Second Attempt

In [114]:
buffer_size = 100
labels = torch.cat([torch.zeros(16), torch.ones(17), 2*torch.ones(16), 3*torch.ones(17), 4*torch.ones(17), 5*torch.ones(17)], dim=0)
data = torch.randn(100)
task_id = 3 # task_id + 1 for python's off by one thing
cpt = 2

class_ids = torch.unique(labels)
data_partitions = [data[torch.argwhere(labels == class_id)].squeeze() for class_id in class_ids]
label_partitions = [labels[torch.argwhere(labels == class_id)].squeeze() for class_id in class_ids]

# Delete Data
for i in range(len(class_ids)):
    
    if i % 2 == 0:
        samples_per_class = int(np.floor(buffer_size / (task_id+1) / cpt))
    else:
        samples_per_class = int(np.ceil(buffer_size / (task_id+1) / cpt))
        
    x = np.abs(samples_per_class - label_partitions[i].shape[0])
    
    data_partitions[i] = data_partitions[i][:-x]
    label_partitions[i] = label_partitions[i][:-x]
    
data_buffer = torch.cat(data_partitions)
label_buffer = torch.cat(label_partitions)

# Add Data
num_to_be_added = buffer_size - data_buffer.shape[0]

for i in range(cpt*task_id, cpt*task_id + 2):
    
    if i % 2 == 0:
        samples_to_add = int(np.floor(num_to_be_added / cpt))
    else:
        samples_to_add = int(np.ceil(num_to_be_added / cpt))
        
    data_buffer = torch.cat((data_buffer, torch.randn(samples_to_add)))
    label_buffer = torch.cat((label_buffer, i*torch.ones(samples_to_add)))

#### Below only works for 5 tasks with 2 classes per task

In [117]:
# data, labels, and buffer_size all exist relative to buffer object (i.e., self.examples, self.labels, self.buffer_size)
# we could also make cpt relative to buffer as well since this typically doesn't change in literature

def delete_data(data_buffer, labels_buffer, buffer_size, task_id, cpt):
    
    class_ids = torch.unique(labels_buffer)
    data_partitions = [data_buffer[torch.argwhere(labels_buffer == class_id)].squeeze() for class_id in class_ids]
    label_partitions = [labels_buffer[torch.argwhere(labels_buffer == class_id)].squeeze() for class_id in class_ids]
    
    #if task_id == 7:
    #    pprint(class_ids)
    #    #pprint(label_partitions)
    #    sys.exit()

    # Delete Data
    for i in range(len(class_ids)):

        if i % 2 == 0:
            samples_per_class = int(np.floor(buffer_size / (task_id+1) / cpt))
        else:
            samples_per_class = int(np.ceil(buffer_size / (task_id+1) / cpt))

        x = np.abs(samples_per_class - label_partitions[i].shape[0])
        
        if t == 7:
            print('\nClass:', i)
            print('Difference:', x)
            print('Current samples per class:', samples_per_class)
            print('Previous samples per class:', label_partitions[i].shape[0])

        data_partitions[i] = data_partitions[i][:-x]
        label_partitions[i] = label_partitions[i][:-x]
        
    data_buffer = torch.cat(data_partitions)
    labels_buffer = torch.cat(label_partitions)
    
    assert data_buffer.shape[0] == labels_buffer.shape[0]
    return data_buffer, labels_buffer

def add_data(data_buffer, labels_buffer, buffer_size, task_id, cpt):
    
    num_to_be_added = buffer_size - data_buffer.shape[0]
    
    for i in range(cpt*task_id, cpt*task_id + cpt):
        
        if i % 2 == 0:
            samples_to_add = int(np.ceil(num_to_be_added / cpt))
        else:
            samples_to_add = int(np.floor(num_to_be_added / cpt))
            
        #if t == 6:
        #    print('\nClass:', i)
        #    print('Difference:', x)
        #    print('Current samples per class:', samples_per_class)
        #    print('Previous samples per class:', label_partitions[i].shape[0])  

        data_buffer = torch.cat((data_buffer, torch.randn(samples_to_add)))
        labels_buffer = torch.cat((labels_buffer, i*torch.ones(samples_to_add)))
    
    assert data_buffer.shape[0] == labels_buffer.shape[0]
    return data_buffer, labels_buffer

In [132]:
buffer_size = 2560
n_tasks = 5
cpt = 2

for t in range(n_tasks):
    
    #print('\nTask {}'.format(t))
    
    if t == 0:
        labels_buffer = torch.tensor([i//(buffer_size/cpt) for i in range(buffer_size)]) # change denom when changing buffer_size and cpt
        data_buffer = torch.randn(buffer_size)
        print('Examples in buffer: {}'.format(torch.unique(labels_buffer, return_counts = True)[1].sum()))
        unique, counts = torch.unique(labels_buffer, return_counts = True)
        #print(unique)
        #print(counts)
        
    else:
        data_buffer, labels_buffer = delete_data(data_buffer, labels_buffer, buffer_size, t, cpt)
        #if t == 7:
        #    sys.exit()
        data_buffer, labels_buffer = add_data(data_buffer, labels_buffer, buffer_size, t, cpt)
        print('Examples in buffer: {}'.format(torch.unique(labels_buffer, return_counts = True)[1].sum()))
        unique, counts = torch.unique(labels_buffer, return_counts = True)
        #print(unique)
        #print(counts)

Examples in buffer: 2560
Examples in buffer: 2560
Examples in buffer: 2560
Examples in buffer: 2560
Examples in buffer: 2560


### Work to Generalize to N tasks with m cpt

In [9]:
# data, labels, and buffer_size all exist relative to buffer object (i.e., self.examples, self.labels, self.buffer_size)
# we could also make cpt relative to buffer as well since this typically doesn't change in literature

def delete_data(data_buffer, labels_buffer, buffer_size, task_id, cpt):
    
    class_ids = torch.unique(labels_buffer)
    data_partitions = [data_buffer[torch.argwhere(labels_buffer == class_id)].squeeze() for class_id in class_ids]
    label_partitions = [labels_buffer[torch.argwhere(labels_buffer == class_id)].squeeze() for class_id in class_ids]

    # Delete Data
    #counter = 0
    for i in range(len(class_ids)):

        if i % 2 == 0:
            samples_per_class = int(np.floor(buffer_size / (task_id+1) / cpt))
        else:
            samples_per_class = int(np.ceil(buffer_size / (task_id+1) / cpt))
        
        # Edge Cases for when when memory buget runs out and each class only has one sample
        if label_partitions[i].shape == torch.Size([]):
            x = 0
            label_partitions[i] = label_partitions[i].unsqueeze(0)
        elif samples_per_class == 2:
            x = 1
        else:
            # Normal situation when edge cases are not an issue
            x = np.abs(samples_per_class - label_partitions[i].shape[0])
            
        # Another edge case scenario
        if data_partitions[i].shape == torch.Size([]):
            data_partitions[i] = data_partitions[i].unsqueeze(0)
        
        # x = 0 corresponds to nothing to delete when creates empty class arrays for storage
        if x != 0:
            data_partitions[i] = data_partitions[i][:-x]
            label_partitions[i] = label_partitions[i][:-x]
        else:
            pass

    data_buffer = torch.cat(data_partitions)
    labels_buffer = torch.cat(label_partitions)
    
    assert data_buffer.shape[0] == labels_buffer.shape[0]
    return data_buffer, labels_buffer

def add_data(data_buffer, labels_buffer, buffer_size, task_id, cpt):
    
    num_to_be_added = buffer_size - data_buffer.shape[0]
    print('Added {} samples to buffer'.format(num_to_be_added))
    
    for i in range(cpt*task_id, cpt*task_id + cpt):
        
        if i % 2 == 0:
            samples_to_add = int(np.floor(num_to_be_added / cpt))
        else:
            samples_to_add = int(np.ceil(num_to_be_added / cpt))

        data_buffer = torch.cat((data_buffer, torch.randn(samples_to_add)))
        labels_buffer = torch.cat((labels_buffer, i*torch.ones(samples_to_add)))
    
    assert data_buffer.shape[0] == labels_buffer.shape[0]
    return data_buffer, labels_buffer

In [10]:
buffer_size = 200
n_tasks = 10
cpt = 20

for t in range(n_tasks):
    
    print('\nTask {}'.format(t))
    
    if t == 0:
        labels_buffer = torch.tensor([i//(buffer_size/cpt) for i in range(buffer_size)]) # change denom when changing buffer_size and cpt
        data_buffer = torch.randn(buffer_size)
        print('Examples in buffer: {}'.format(torch.unique(labels_buffer, return_counts = True)[1].sum()))
        unique, counts = torch.unique(labels_buffer, return_counts = True)
        print(unique)
        print(counts)
        
    else:
        data_buffer, labels_buffer = delete_data(data_buffer, labels_buffer, buffer_size, t, cpt)
        data_buffer, labels_buffer = add_data(data_buffer, labels_buffer, buffer_size, t, cpt)
        print('Examples in buffer: {}'.format(torch.unique(labels_buffer, return_counts = True)[1].sum()))
        unique, counts = torch.unique(labels_buffer, return_counts = True)
        print(unique)
        print(counts)


Task 0
Examples in buffer: 200
tensor([ 0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9., 10., 11., 12., 13.,
        14., 15., 16., 17., 18., 19.])
tensor([10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10,
        10, 10])

Task 1
Added 100 samples to buffer
Examples in buffer: 200
tensor([ 0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9., 10., 11., 12., 13.,
        14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24., 25., 26., 27.,
        28., 29., 30., 31., 32., 33., 34., 35., 36., 37., 38., 39.])
tensor([5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5,
        5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5])

Task 2
Added 60 samples to buffer
Examples in buffer: 200
tensor([ 0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9., 10., 11., 12., 13.,
        14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24., 25., 26., 27.,
        28., 29., 30., 31., 32., 33., 34., 35., 36., 37., 38., 39., 40., 41.,
        42., 43., 44., 45., 46., 47., 4

#### If decide to run TinyImageNet Experiments with a buffer size of 100, we will have to further adapt to store more recently encountered examples as opposed to original examples. Although, an arguement can be that that we should keep the older samples since this are the samples most likely to be forgotten as compared to the most recent sampels