In [2]:
import torch
import numpy as np
from torchmetrics.classification import BinaryAccuracy
import sklearn.metrics

  from .autonotebook import tqdm as notebook_tqdm


In [10]:
class XBM:
    def __init__(self):
        self.K = 20
        self.feats = torch.zeros(self.K, 2).cuda()
        self.targets = torch.zeros(self.K, dtype=torch.long).cuda()
        self.targets[:]=-1
        self.ptr = 0

    @property
    def is_full(self):
        return self.targets[-1].item() != -1

    def get(self):
        if self.is_full:
            return self.feats, self.targets
        else:
            return self.feats[:self.ptr], self.targets[:self.ptr]

    def enqueue_dequeue(self, feats, targets):
        q_size = len(targets)
        print(q_size)
        if self.ptr + q_size > self.K:
            self.feats[-q_size:] = feats
            self.targets[-q_size:] = targets
            self.ptr = 0
        else:
            self.feats[self.ptr: self.ptr + q_size] = feats
            self.targets[self.ptr: self.ptr + q_size] = targets
            self.ptr += q_size

In [11]:
xbm = XBM()

In [12]:
xbm.enqueue_dequeue(torch.rand(10, 2), torch.tensor([1,2,3,4,5,6,7,8,9,10]))
print(xbm.targets)
xbm.enqueue_dequeue(torch.rand(10, 2), torch.tensor([1,2,3,4,5,6,7,8,9,10]))
print(xbm.targets)
xbm.enqueue_dequeue(torch.rand(10, 2), torch.tensor([0,0,0,0,0,0,0,0,0,0]))
print(xbm.targets)
xbm.enqueue_dequeue(torch.rand(10, 2), torch.tensor([13,13,13,13,13,13,13,13,13,13]))
print(xbm.targets)

10
tensor([ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10, -1, -1, -1, -1, -1, -1, -1, -1,
        -1, -1], device='cuda:0')
10
tensor([ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10,  1,  2,  3,  4,  5,  6,  7,  8,
         9, 10], device='cuda:0')
10
tensor([ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0], device='cuda:0')
10
tensor([13, 13, 13, 13, 13, 13, 13, 13, 13, 13,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0], device='cuda:0')


In [15]:
xbm.enqueue_dequeue(torch.rand(10, 2), torch.tensor([1,2,3,4,5,6,7,8,9,10]))
print(xbm.targets)


10
tensor([ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10,  1,  2,  3,  4,  5,  6,  7,  8,
         9, 10], device='cuda:0')


In [46]:
xbm.enqueue_dequeue(torch.rand(1, 2), torch.tensor([91]))
print(xbm.targets)

1
tensor([99, 99, 99, 99, 99, 99, 99, 99, 99, 99, 91, 91, 91, 91, 91, 91, 91, 91,
        91, 91], device='cuda:0')
