In [1]:
import torch as th

from sklearn.datasets import make_classification
from torch.utils.data import TensorDataset, DataLoader
from torch.autograd import Variable as V

from tqdm import tqdm

from pymips.plugins.torch import ApproximateLinear

Failed to load GPU Faiss: No module named 'faiss.swigfaiss_gpu'
Faiss falling back to CPU-only.


### Create dummy dataset

In [2]:
# Synthetic multi-class problem: 100k samples, 256 (all informative) features,
# 10k classes. A fixed random_state makes the generated data identical on every
# Restart & Run All instead of changing from run to run.
features_np, labels_np = make_classification(
    n_samples=100_000,
    n_features=256,
    n_informative=256,
    n_redundant=0,
    n_classes=10_000,
    random_state=0,
)

# Convert the NumPy arrays to torch tensors (float features, long labels).
X = th.from_numpy(features_np).float()
Y = th.from_numpy(labels_np).long()

### Create the model

In [43]:
def index_factory(d):
    """Build the Faiss index handed to ``ApproximateLinear`` via ``index_factory=``.

    Args:
        d: Dimensionality of the vectors the index will hold.

    Returns:
        A Faiss index created from the ``"IVF256,Flat"`` factory string with
        the inner-product metric and ``nprobe`` set to 32.
    """
    # The notebook never imports ``faiss`` at top level, so the bare name raised
    # NameError on a fresh kernel. Importing here keeps this cell self-contained.
    import faiss

    index = faiss.index_factory(d, "IVF256,Flat", faiss.METRIC_INNER_PRODUCT)
    # Number of inverted lists visited per query (out of the 256 in "IVF256").
    index.nprobe = 32

    return index

# Classifier: one dense hidden layer with ReLU, followed by the approximate
# output layer over 10k classes (its index comes from `index_factory`).
hidden_layer = th.nn.Linear(256, 256)
output_layer = ApproximateLinear(256, 10_000, index_factory=index_factory)
model = th.nn.Sequential(hidden_layer, th.nn.ReLU(), output_layer)

# Dataset wrapper, a default (unshuffled) loader, and an Adam optimizer over
# all model parameters.
dset = TensorDataset(X, Y)
loader = DataLoader(dataset=dset, batch_size=64)
adam = th.optim.Adam(params=model.parameters())

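### Train

*Note: no training cell appears in this notebook as shown, so the evaluations below run on the randomly initialised model; accuracy near 1/10,000 (chance level for 10,000 classes) is therefore expected.*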

### Evaluate full (model in training mode, argmax over all class scores)

In [44]:
# Full evaluation: score every class and take the argmax per sample.
preds, targets = [], []
loader = DataLoader(dset, batch_size=64, shuffle=False)

# The approximate-evaluation cell further down calls `model.eval()`, after which
# `model(x)` is unpacked as a 2-tuple there. Force training mode so re-running
# this cell still receives a single score tensor.
# NOTE(review): assumes ApproximateLinear returns dense scores in training mode,
# as the argmax below requires -- confirm against pymips.
model.train()

with th.no_grad():
    for x, y in tqdm(loader):
        # Tensors are fed directly: wrapping them in the deprecated `Variable`
        # has been a no-op since tensors and variables were merged.
        scores = model(x)

        preds.append(scores.argmax(dim=1))
        targets.append(y)

# One flat tensor of predictions / labels over the whole dataset.
preds   = th.cat(preds)
targets = th.cat(targets)

100%|██████████| 1563/1563 [00:11<00:00, 135.32it/s]


In [45]:
# Accuracy = number of matching predictions divided by the number of samples.
n_correct = (preds == targets).float().sum()
n_total   = float(preds.size(0))
acc = float(n_correct / n_total)
print(f'{acc:.5f}')

0.00012


### Evaluate approximate (model in eval mode, predictions taken from the model's second output)

In [46]:
# Approximate evaluation: with the model in eval mode its output is unpacked as
# a pair, and the second element is used as the predicted class indices.
preds, targets = [], []
loader = DataLoader(dset, batch_size=64, shuffle=False)
model.eval()

with th.no_grad():
    for x, y in tqdm(loader):
        # Tensors are fed directly: wrapping them in the deprecated `Variable`
        # has been a no-op since tensors and variables were merged.
        _, p = model(x)

        # Flatten with reshape(-1) instead of a bare squeeze(): squeeze() turns a
        # one-sample batch into a 0-d tensor, which th.cat refuses to concatenate.
        # NOTE(review): assumes `p` holds exactly one index per sample
        # (e.g. shape (batch, 1)) -- confirm against pymips.
        preds.append(p.reshape(-1))
        targets.append(y.reshape(-1))

# One flat tensor of predictions / labels over the whole dataset.
preds   = th.cat(preds)
targets = th.cat(targets)

100%|██████████| 1563/1563 [00:02<00:00, 560.65it/s]


In [47]:
# Share of samples where the predicted index matches the true label.
hits = (preds == targets).float().sum()
total = float(preds.size(0))
acc = float(hits / total)
print(f'{acc:.5f}')

0.00013
