Skip to content

Commit

Permalink
add cpu support
Browse files Browse the repository at this point in the history
  • Loading branch information
bongjun committed Jun 18, 2019
1 parent 5090879 commit a7e4840
Show file tree
Hide file tree
Showing 5 changed files with 18 additions and 4 deletions.
16 changes: 13 additions & 3 deletions evaluate_ensemble.py
Expand Up @@ -29,7 +29,11 @@ def predict_ensemble(mel_list, test_file_idxs, model1, model2):

test_x = mel_list[idx]
test_x = np.reshape(test_x,(1,1,test_x.shape[0],test_x.shape[1]))
test_x = torch.from_numpy(test_x).cuda().float()

if torch.cuda.is_available():
test_x = torch.from_numpy(test_x).cuda().float()
else:
test_x = torch.from_numpy(test_x).float()

model_output = (model1(test_x) + model2(test_x))/2

Expand Down Expand Up @@ -67,19 +71,25 @@ def evaluate_ensemble(annotation_path, taxonomy_path, mel_dir, models_dir1,
model_filename = model_list[np.argmin(val_loss)]

model1 = MyCNN()
model1.load_state_dict(torch.load(os.path.join(models_dir1, model_filename)))
if torch.cuda.is_available():
model1.load_state_dict(torch.load(os.path.join(models_dir1, model_filename)))
model1.cuda()
else:
model1.load_state_dict(torch.load(os.path.join(models_dir1, model_filename), map_location='cpu'))

model1.eval()

model_list = [f for f in os.listdir(models_dir2) if 'pth' in f]
val_loss = [float(f.split('_')[-1][:-4]) for f in model_list]
model_filename = model_list[np.argmin(val_loss)]

model2 = MyCNN()
model2.load_state_dict(torch.load(os.path.join(models_dir2, model_filename)))
if torch.cuda.is_available():
model2.load_state_dict(torch.load(os.path.join(models_dir2, model_filename)))
model2.cuda()
else:
model2.load_state_dict(torch.load(os.path.join(models_dir2, model_filename), map_location='cpu'))

model2.eval()

y_pred = predict_ensemble(mel_list, test_file_idxs, model1, model2)
Expand Down
6 changes: 5 additions & 1 deletion train.py
Expand Up @@ -104,7 +104,11 @@ def predict(mel_list, test_file_idxs, model):

test_x = mel_list[idx]
test_x = np.reshape(test_x,(1,1,test_x.shape[0],test_x.shape[1]))
test_x = torch.from_numpy(test_x).cuda().float()

if torch.cuda.is_available():
test_x = torch.from_numpy(test_x).cuda().float()
else:
test_x = torch.from_numpy(test_x).float()

model_output = model(test_x)

Expand Down
Binary file modified urban_sound_tagging_baseline/__pycache__/metrics.cpython-35.pyc
Binary file not shown.
Binary file modified vggish_utils/__pycache__/mel_features.cpython-35.pyc
Binary file not shown.
Binary file modified vggish_utils/__pycache__/vggish_params.cpython-35.pyc
Binary file not shown.

0 comments on commit a7e4840

Please sign in to comment.