Skip to content

Commit

Permalink
Workaroun mixed device issue in multiple engine case
Browse files Browse the repository at this point in the history
Change-Id: Iff9c8f0cfb6ca56d035436c6278e34d3590fbe78
  • Loading branch information
frankfliu committed Jul 27, 2021
1 parent 154ae6e commit 79fa6fa
Showing 1 changed file with 11 additions and 3 deletions.
Expand Up @@ -12,7 +12,6 @@
*/
package ai.djl.integration.tests.nn;

import ai.djl.Device;
import ai.djl.MalformedModelException;
import ai.djl.Model;
import ai.djl.engine.Engine;
Expand Down Expand Up @@ -57,6 +56,7 @@
import java.nio.file.Files;
import java.util.Arrays;
import org.testng.Assert;
import org.testng.SkipException;
import org.testng.annotations.Test;

public class BlockCoreTest {
Expand Down Expand Up @@ -207,12 +207,16 @@ public void testBatchNorm() throws IOException, MalformedModelException {
@SuppressWarnings("try")
@Test
public void testLayerNorm() throws IOException, MalformedModelException {
if (!"PyTorch".equals(Engine.getInstance().getEngineName())) {
throw new SkipException("Only works for PyTorch engine.");
}

TrainingConfig config =
new DefaultTrainingConfig(Loss.l2Loss())
.optInitializer(Initializer.ONES, Parameter.Type.WEIGHT);

Block block = LayerNorm.builder().build();
try (Model model = Model.newInstance("model", Device.cpu(), "PyTorch")) {
try (Model model = Model.newInstance("model")) {
model.setBlock(block);

try (Trainer trainer = model.newTrainer(config)) {
Expand All @@ -234,12 +238,16 @@ public void testLayerNorm() throws IOException, MalformedModelException {
@SuppressWarnings("try")
@Test
public void test2LayerNorm() throws IOException, MalformedModelException {
if (!"PyTorch".equals(Engine.getInstance().getEngineName())) {
throw new SkipException("Only works for PyTorch engine.");
}

TrainingConfig config =
new DefaultTrainingConfig(Loss.l2Loss())
.optInitializer(Initializer.ONES, Parameter.Type.WEIGHT);

Block block = LayerNorm.builder().axis(2, 3).build();
try (Model model = Model.newInstance("model", Device.cpu(), "PyTorch")) {
try (Model model = Model.newInstance("model")) {
model.setBlock(block);

try (Trainer trainer = model.newTrainer(config)) {
Expand Down

0 comments on commit 79fa6fa

Please sign in to comment.