Skip to content

Commit

Permalink
test(tuner): add test for overfitting (#109)
Browse files Browse the repository at this point in the history
* test(tuner): add test for overfitting

* fix: apply suggestions from code review

Co-authored-by: Wang Bo <bo.wang@jina.ai>

* fix(tuner): dim

Co-authored-by: Wang Bo <bo.wang@jina.ai>
  • Loading branch information
Tadej Svetina and bwanglzu committed Oct 11, 2021
1 parent a6d16ff commit 562c65f
Show file tree
Hide file tree
Showing 5 changed files with 253 additions and 1 deletion.
3 changes: 2 additions & 1 deletion .github/requirements-cicd.txt
Expand Up @@ -2,4 +2,5 @@ numpy
tensorflow
paddlepaddle
torch
torchvision
torchvision
scipy
61 changes: 61 additions & 0 deletions tests/integration/conftest.py
@@ -1,4 +1,8 @@
import numpy as np
import pytest
from jina import Document, DocumentArray

from finetuner import __default_tag_key__


@pytest.fixture
Expand All @@ -15,3 +19,60 @@ def params():
'num_predict': 100,
'max_seq_len': 10,
}


@pytest.fixture
def create_easy_data():
    def create_easy_data_fn(n_cls: int, dim: int, n_sample: int):
        """Build a toy anchor/positive/negative dataset out of random vectors.

        Two random vectors are drawn per class, giving ``2 * n_cls`` unique
        vectors in total; the vectors at indices ``2k`` and ``2k + 1`` form
        class ``k``. The dataset is then produced by cycling over those
        vectors ``n_sample`` times (vectors repeat once ``n_sample`` exceeds
        ``2 * n_cls``). Every anchor document receives as matches:

        - its class partner, labelled ``1``;
        - every vector of every other class, labelled ``-1``, i.e.
          ``2 * (n_cls - 1)`` negatives, so each vector is contrasted with
          all the others.

        The vectors are unrelated random draws. The dataset exists to check
        that an over-parametrized model can pull the pairs together and push
        everything else apart, i.e. that training works at all.

        :param n_cls: number of classes (two vectors are created per class)
        :param dim: dimensionality of each random vector
        :param n_sample: number of anchor documents to generate
        :return: tuple of the anchor documents (with matches attached) and
            the ``(2 * n_cls, dim)`` float32 array of unique vectors
        """

        # Seeded generator: the data is identical on every call, which makes
        # failures debuggable
        rng = np.random.default_rng(42)

        n_vecs = 2 * n_cls
        rand_vecs = rng.uniform(size=(n_vecs, dim)).astype(np.float32)

        def _labelled_match(vec_ind: int, label: int) -> Document:
            # A match document carrying one of the class vectors and its label
            return Document(
                blob=rand_vecs[vec_ind],
                tags={__default_tag_key__: {'label': label}},
            )

        triplets = DocumentArray()
        for sample_no in range(n_sample):
            anchor_ind = sample_no % n_vecs
            # Partners sit at (2k, 2k + 1), so flipping the lowest bit maps
            # an even index to the next one and an odd index to the previous
            pos_ind = anchor_ind ^ 1

            anchor = Document(blob=rand_vecs[anchor_ind])

            # Positive first, then all remaining vectors as negatives
            anchor.matches.append(_labelled_match(pos_ind, 1))
            for other_ind in range(n_vecs):
                if other_ind != anchor_ind and other_ind != pos_ind:
                    anchor.matches.append(_labelled_match(other_ind, -1))

            triplets.append(anchor)

        return triplets, rand_vecs

    return create_easy_data_fn
62 changes: 62 additions & 0 deletions tests/integration/keras/test_overfit.py
@@ -0,0 +1,62 @@
import pytest
import tensorflow as tf
from scipy.spatial.distance import pdist, squareform

from finetuner.tuner.keras import KerasTuner


@pytest.mark.parametrize(
    "n_cls,dim,n_samples,n_epochs,batch_size,head_layer",
    [
        (5, 10, 100, 5, 25, 'TripletLayer'),
        (5, 10, 1000, 15, 256, 'CosineLayer'),  # Cosine needs more training to converge
    ],
)
def test_overfit_keras(
    create_easy_data,
    n_cls: int,
    dim: int,
    n_samples: int,
    n_epochs: int,
    batch_size: int,
    head_layer: str,
):
    """Check that the Keras tuner can overfit a tiny random dataset.

    The embedding model is heavily over-parametrized relative to the handful
    of unique (random) input vectors, so after training the two vectors of
    each class should end up close together and away from all other classes.
    """

    # Data: anchor documents plus the unique vectors they were built from
    data, vecs = create_easy_data(n_cls, dim, n_samples)

    # Model: flatten -> 3 x (dense 64 + relu) -> dense 32
    stack = [tf.keras.layers.Flatten(input_shape=(dim,))]
    for _ in range(3):
        stack.append(tf.keras.layers.Dense(64, activation='relu'))
    stack.append(tf.keras.layers.Dense(32))
    embed_model = tf.keras.Sequential(stack)

    # Fit the embedding model through the tuner
    tuner = KerasTuner(embed_model, head_layer=head_layer)
    tuner.fit(train_data=data, epochs=n_epochs, batch_size=batch_size)

    # Embed the unique vectors with the trained model
    embeddings = embed_model(vecs).numpy()

    # Pairwise distance matrix, using the metric that matches the head
    if head_layer == 'TripletLayer':
        metric = 'sqeuclidean'
    else:
        metric = 'cosine'
    pairwise = squareform(pdist(embeddings, metric=metric))

    # Class k owns rows/columns 2k and 2k + 1 of the distance matrix
    for cls_ind in range(n_cls):
        lo, hi = 2 * cls_ind, 2 * cls_ind + 2
        same_cls_dist = pairwise[lo, lo + 1]

        # Distances from both class members to everything else; the class's
        # own columns are overwritten with a large sentinel so they never
        # win the minimum
        to_others = pairwise[lo:hi, :].copy()
        to_others[:, lo:hi] = 10_000

        # Note the slack of 1: the intra-class distance only has to stay
        # below the nearest foreign distance plus one
        assert same_cls_dist < to_others.min() + 1
64 changes: 64 additions & 0 deletions tests/integration/paddle/test_overfit.py
@@ -0,0 +1,64 @@
import paddle
import pytest
from paddle import nn
from scipy.spatial.distance import pdist, squareform

from finetuner.tuner.paddle import PaddleTuner


@pytest.mark.parametrize(
    "n_cls,dim,n_samples,n_epochs,batch_size,head_layer",
    [
        (5, 10, 100, 5, 25, 'TripletLayer'),
        (5, 10, 1000, 15, 256, 'CosineLayer'),  # Cosine needs more training to converge
    ],
)
def test_overfit_paddle(
    create_easy_data,
    n_cls: int,
    dim: int,
    n_samples: int,
    n_epochs: int,
    batch_size: int,
    head_layer: str,
):
    """Check that the Paddle tuner can overfit a tiny random dataset.

    The embedding model is heavily over-parametrized relative to the handful
    of unique (random) input vectors, so after training the two vectors of
    each class should end up close together and away from all other classes.
    """

    # Data: anchor documents plus the unique vectors they were built from
    data, vecs = create_easy_data(n_cls, dim, n_samples)

    # Model: flatten -> 3 x (linear 64 + relu) -> linear 32
    stack = [nn.Flatten()]
    for n_in in (dim, 64, 64):
        stack.append(nn.Linear(in_features=n_in, out_features=64))
        stack.append(nn.ReLU())
    stack.append(nn.Linear(in_features=64, out_features=32))
    embed_model = nn.Sequential(*stack)

    # Fit the embedding model through the tuner
    tuner = PaddleTuner(embed_model, head_layer=head_layer)
    tuner.fit(train_data=data, epochs=n_epochs, batch_size=batch_size)

    # Embed the unique vectors with the trained model
    embeddings = embed_model(paddle.Tensor(vecs)).numpy()

    # Pairwise distance matrix, using the metric that matches the head
    if head_layer == 'TripletLayer':
        metric = 'sqeuclidean'
    else:
        metric = 'cosine'
    pairwise = squareform(pdist(embeddings, metric=metric))

    # Class k owns rows/columns 2k and 2k + 1 of the distance matrix
    for cls_ind in range(n_cls):
        lo, hi = 2 * cls_ind, 2 * cls_ind + 2
        same_cls_dist = pairwise[lo, lo + 1]

        # Distances from both class members to everything else; the class's
        # own columns are overwritten with a large sentinel so they never
        # win the minimum
        to_others = pairwise[lo:hi, :].copy()
        to_others[:, lo:hi] = 10_000

        # Note the slack of 1: the intra-class distance only has to stay
        # below the nearest foreign distance plus one
        assert same_cls_dist < to_others.min() + 1
64 changes: 64 additions & 0 deletions tests/integration/torch/test_overfit.py
@@ -0,0 +1,64 @@
import pytest
import torch
from scipy.spatial.distance import pdist, squareform

from finetuner.tuner.pytorch import PytorchTuner


@pytest.mark.parametrize(
    "n_cls,dim,n_samples,n_epochs,batch_size,head_layer",
    [
        (5, 10, 100, 5, 25, 'TripletLayer'),
        (5, 10, 1000, 15, 256, 'CosineLayer'),  # Cosine needs more training to converge
    ],
)
def test_overfit_pytorch(
    create_easy_data,
    n_cls: int,
    dim: int,
    n_samples: int,
    n_epochs: int,
    batch_size: int,
    head_layer: str,
):
    """Check that the PyTorch tuner can overfit a tiny random dataset.

    The embedding model is heavily over-parametrized relative to the handful
    of unique (random) input vectors, so after training the two vectors of
    each class should end up close together and away from all other classes.
    """

    # Data: anchor documents plus the unique vectors they were built from
    data, vecs = create_easy_data(n_cls, dim, n_samples)

    # Model: flatten -> 3 x (linear 64 + relu) -> linear 32
    stack = [torch.nn.Flatten()]
    for n_in in (dim, 64, 64):
        stack.append(torch.nn.Linear(in_features=n_in, out_features=64))
        stack.append(torch.nn.ReLU())
    stack.append(torch.nn.Linear(in_features=64, out_features=32))
    embed_model = torch.nn.Sequential(*stack)

    # Fit the embedding model through the tuner
    tuner = PytorchTuner(embed_model, head_layer=head_layer)
    tuner.fit(train_data=data, epochs=n_epochs, batch_size=batch_size)

    # Embed the unique vectors with the trained model; no gradients needed
    with torch.inference_mode():
        embeddings = embed_model(torch.Tensor(vecs)).numpy()

    # Pairwise distance matrix, using the metric that matches the head
    if head_layer == 'TripletLayer':
        metric = 'sqeuclidean'
    else:
        metric = 'cosine'
    pairwise = squareform(pdist(embeddings, metric=metric))

    # Class k owns rows/columns 2k and 2k + 1 of the distance matrix
    for cls_ind in range(n_cls):
        lo, hi = 2 * cls_ind, 2 * cls_ind + 2
        same_cls_dist = pairwise[lo, lo + 1]

        # Distances from both class members to everything else; the class's
        # own columns are overwritten with a large sentinel so they never
        # win the minimum
        to_others = pairwise[lo:hi, :].copy()
        to_others[:, lo:hi] = 10_000

        # Note the slack of 1: the intra-class distance only has to stay
        # below the nearest foreign distance plus one
        assert same_cls_dist < to_others.min() + 1

0 comments on commit 562c65f

Please sign in to comment.