Skip to content

Commit

Permalink
Inline nn.Flatten in test for PyTorch 1.1.0
Browse files Browse the repository at this point in the history
  • Loading branch information
sublee committed Nov 9, 2019
1 parent ff3bab8 commit 995be0e
Showing 1 changed file with 6 additions and 1 deletion.
7 changes: 6 additions & 1 deletion tests/test_balance.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,12 @@ def forward(self, x):


def test_balance_by_time_loop_resets_input():
model = nn.Sequential(nn.Conv2d(3, 2, 1), nn.Flatten(), nn.Linear(128, 10))
# nn.Flatten was introduced at PyTorch 1.2.0.
class Flatten(nn.Module):
def forward(self, x):
return x.flatten(1)

model = nn.Sequential(nn.Conv2d(3, 2, 1), Flatten(), nn.Linear(128, 10))
sample = torch.rand(10, 3, 8, 8)
balance = balance_by_time(2, model, sample, device='cpu')
assert balance == [1, 2]
Expand Down

0 comments on commit 995be0e

Please sign in to comment.