Commit

test: fixed

maximilianwerk committed Oct 18, 2021
1 parent 10e1cc0 commit 2d6811d
Showing 9 changed files with 35 additions and 38 deletions.
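Taken together, the nine files migrate the integration tests from the per-framework tuner classes (KerasTuner, PaddleTuner, PytorchTuner) to the framework-agnostic fit and save helpers in finetuner.tuner. A minimal before/after sketch of the pattern, distilled from the diffs below; embed_model, loss, data, n_epochs, batch_size, and model_path stand in for whatever each test supplies:

    # before: instantiate the framework-specific tuner, then drive it
    from finetuner.tuner.pytorch import PytorchTuner  # or .keras / .paddle

    pt = PytorchTuner(embed_model, loss=loss)
    pt.fit(train_data=data, epochs=n_epochs, batch_size=batch_size)
    pt.save(model_path)

    # after: call the framework-agnostic helpers, passing the model explicitly
    from finetuner.tuner import fit, save

    fit(embed_model, loss=loss, train_data=data, epochs=n_epochs, batch_size=batch_size)
    save(embed_model, model_path)

Passing the model explicitly is what lets a single fit/save pair serve all three backends.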
tests/integration/keras/test_keras_trainer.py (9 additions, 9 deletions)
@@ -3,7 +3,7 @@
 import tensorflow as tf
 from tensorflow import keras
 
-from finetuner.tuner.keras import KerasTuner
+from finetuner.tuner import fit, save
 from finetuner.toydata import generate_fashion_match
 from finetuner.toydata import generate_qa_match
 
@@ -29,10 +29,10 @@ def test_simple_sequential_model(tmpdir, params, loss):
         ]
     )
 
-    kt = KerasTuner(user_model, loss=loss)
-
     # fit and save the checkpoint
-    kt.fit(
+    fit(
+        user_model,
+        loss=loss,
         train_data=lambda: generate_fashion_match(
             num_pos=10, num_neg=10, num_total=params['num_train']
         ),
@@ -42,7 +42,7 @@ def test_simple_sequential_model(tmpdir, params, loss):
         epochs=params['epochs'],
         batch_size=params['batch_size'],
     )
-    kt.save(tmpdir / 'trained.kt')
+    save(user_model, tmpdir / 'trained.kt')
 
     embedding_model = keras.models.load_model(tmpdir / 'trained.kt')
     r = embedding_model.predict(
@@ -63,10 +63,10 @@ def test_simple_lstm_model(tmpdir, params, loss):
         ]
     )
 
-    kt = KerasTuner(user_model, loss=loss)
-
     # fit and save the checkpoint
-    kt.fit(
+    fit(
+        user_model,
+        loss=loss,
         train_data=lambda: generate_qa_match(
             num_total=params['num_train'],
             max_seq_len=params['max_seq_len'],
@@ -82,7 +82,7 @@ def test_simple_lstm_model(tmpdir, params, loss):
         epochs=params['epochs'],
         batch_size=params['batch_size'],
     )
-    kt.save(tmpdir / 'trained.kt')
+    save(user_model, tmpdir / 'trained.kt')
 
     embedding_model = keras.models.load_model(tmpdir / 'trained.kt')
     r = embedding_model.predict(
tests/integration/keras/test_overfit.py (2 additions, 3 deletions)
@@ -2,7 +2,7 @@
 import tensorflow as tf
 from scipy.spatial.distance import pdist, squareform
 
-from finetuner.tuner.keras import KerasTuner
+from finetuner.tuner import fit
 
 
 @pytest.mark.parametrize(
@@ -45,8 +45,7 @@ def test_overfit_keras(
     )
 
     # Train
-    pt = KerasTuner(embed_model, loss=loss)
-    pt.fit(train_data=data, epochs=n_epochs, batch_size=batch_size)
+    fit(embed_model, loss=loss, train_data=data, epochs=n_epochs, batch_size=batch_size)
 
     # Compute embedding for original vectors
     vec_embedings = embed_model(vecs).numpy()
tests/integration/keras/test_tail_and_tune.py (1 addition, 1 deletion)
@@ -28,4 +28,4 @@ def test_tail_and_tune(embed_model, create_easy_data):
         output_dim=16,
         layer_name='dense_2',
     )
-    assert rv['loss']['train']
+    assert rv._loss_train
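Note that all three test_tail_and_tune variants (keras here, paddle and torch below) make the same one-line change: the training summary is no longer indexed as a nested dict, rv['loss']['train'], but read as an attribute, rv._loss_train. This suggests fit now returns a result object rather than a plain dict, though the diff itself does not show the new type's definition.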
tests/integration/paddle/test_overfit.py (2 additions, 3 deletions)
@@ -3,7 +3,7 @@
 from paddle import nn
 from scipy.spatial.distance import pdist, squareform
 
-from finetuner.tuner.paddle import PaddleTuner
+from finetuner.tuner import fit
 
 
 @pytest.mark.parametrize(
@@ -47,8 +47,7 @@ def test_overfit_paddle(
     )
 
     # Train
-    pt = PaddleTuner(embed_model, loss=loss)
-    pt.fit(train_data=data, epochs=n_epochs, batch_size=batch_size)
+    fit(embed_model, loss=loss, train_data=data, epochs=n_epochs, batch_size=batch_size)
 
     # Compute embedding for original vectors
     vec_embedings = embed_model(paddle.Tensor(vecs)).numpy()
tests/integration/paddle/test_paddle_trainer.py (9 additions, 8 deletions)
@@ -3,7 +3,7 @@
 import pytest
 from paddle import nn
 
-from finetuner.tuner.paddle import PaddleTuner
+from finetuner.tuner import fit, save
 from finetuner.toydata import generate_fashion_match
 from finetuner.toydata import generate_qa_match
 
@@ -28,10 +28,11 @@ def test_simple_sequential_model(tmpdir, params, loss):
         nn.Linear(in_features=params['feature_dim'], out_features=params['output_dim']),
     )
 
-    pt = PaddleTuner(user_model, loss=loss)
     model_path = tmpdir / 'trained.pd'
     # fit and save the checkpoint
-    pt.fit(
+    fit(
+        user_model,
+        loss=loss,
         train_data=lambda: generate_fashion_match(
             num_pos=10, num_neg=10, num_total=params['num_train']
         ),
@@ -42,7 +42,7 @@ def test_simple_sequential_model(tmpdir, params, loss):
         batch_size=params['batch_size'],
     )
 
-    pt.save(model_path)
+    save(user_model, model_path)
 
     user_model.set_state_dict(paddle.load(model_path))
     user_model.eval()
@@ -84,10 +84,10 @@ def forward(self, x):
     )
     model_path = tmpdir / 'trained.pd'
 
-    pt = PaddleTuner(user_model, loss=loss)
-
     # fit and save the checkpoint
-    pt.fit(
+    fit(
+        user_model,
+        loss=loss,
         train_data=lambda: generate_qa_match(
             num_total=params['num_train'],
             max_seq_len=params['max_seq_len'],
@@ -103,7 +103,7 @@ def forward(self, x):
         epochs=params['epochs'],
         batch_size=params['batch_size'],
     )
-    pt.save(model_path)
+    save(user_model, model_path)
 
     # load the checkpoint and ensure the dim
     user_model.set_state_dict(paddle.load(model_path))
tests/integration/paddle/test_tail_and_tune.py (1 addition, 1 deletion)
@@ -29,4 +29,4 @@ def test_tail_and_tune(embed_model, create_easy_data):
         output_dim=16,
         layer_name='linear_4',
     )
-    assert rv['loss']['train']
+    assert rv._loss_train
tests/integration/torch/test_overfit.py (2 additions, 4 deletions)
@@ -2,7 +2,7 @@
 import torch
 from scipy.spatial.distance import pdist, squareform
 
-from finetuner.tuner.pytorch import PytorchTuner
+from finetuner.tuner import fit
 
 
 @pytest.mark.parametrize(
@@ -44,10 +44,8 @@ def test_overfit_pytorch(
         torch.nn.ReLU(),
         torch.nn.Linear(in_features=64, out_features=32),
     )
-
     # Train
-    pt = PytorchTuner(embed_model, loss=loss)
-    pt.fit(train_data=data, epochs=n_epochs, batch_size=batch_size)
+    fit(embed_model, loss=loss, train_data=data, epochs=n_epochs, batch_size=batch_size)
 
     # Compute embedding for original vectors
     with torch.inference_mode():
tests/integration/torch/test_tail_and_tune.py (1 addition, 1 deletion)
@@ -29,4 +29,4 @@ def test_tail_and_tune(embed_model, create_easy_data):
         output_dim=16,
         layer_name='linear_4',
     )
-    assert rv['loss']['train']
+    assert rv._loss_train
tests/integration/torch/test_torch_trainer.py (8 additions, 8 deletions)
@@ -5,7 +5,7 @@
 import torch
 import torch.nn as nn
 
-from finetuner.tuner.pytorch import PytorchTuner
+from finetuner.tuner import fit, save
 from finetuner.toydata import generate_fashion_match
 from finetuner.toydata import generate_qa_match
 
@@ -31,10 +31,10 @@ def test_simple_sequential_model(tmpdir, params, loss):
     )
     model_path = os.path.join(tmpdir, 'trained.pth')
 
-    pt = PytorchTuner(user_model, loss=loss)
-
     # fit and save the checkpoint
-    pt.fit(
+    fit(
+        user_model,
+        loss=loss,
         train_data=lambda: generate_fashion_match(
             num_pos=10, num_neg=10, num_total=params['num_train']
         ),
@@ -44,7 +44,7 @@ def test_simple_sequential_model(tmpdir, params, loss):
         epochs=params['epochs'],
         batch_size=params['batch_size'],
     )
-    pt.save(model_path)
+    save(user_model, model_path)
 
     # load the checkpoint and ensure the dim
     user_model.load_state_dict(torch.load(model_path))
@@ -88,10 +88,10 @@ def forward(self, x):
     )
     model_path = os.path.join(tmpdir, 'trained.pth')
 
-    pt = PytorchTuner(user_model, loss=loss)
-
     # fit and save the checkpoint
-    pt.fit(
+    fit(
+        user_model,
+        loss=loss,
         train_data=lambda: generate_qa_match(
             num_total=params['num_train'],
             max_seq_len=params['max_seq_len'],
@@ -107,7 +107,7 @@ def forward(self, x):
         epochs=params['epochs'],
         batch_size=params['batch_size'],
     )
-    pt.save(model_path)
+    save(user_model, model_path)
 
     # load the checkpoint and ensure the dim
     user_model.load_state_dict(torch.load(model_path))
