Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
…into main
  • Loading branch information
chaofengc committed Dec 14, 2023
2 parents d25d50a + 921bd75 commit 8a81159
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 2 deletions.
3 changes: 1 addition & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cp
# create metric with default setting
iqa_metric = pyiqa.create_metric('lpips', device=device)
# Note that gradient propagation is disabled by default. set as_loss=True to enable it as a loss function.
iqa_loss = pyiqa.create_metric('lpips', device=device, as_loss=True)
# iqa_loss = pyiqa.create_metric('lpips', device=device, as_loss=True)
# create metric with custom setting
iqa_metric = pyiqa.create_metric('psnr', test_y_channel=True, color_space='ycbcr').to(device)
Expand All @@ -94,7 +94,6 @@ print(iqa_metric.lower_better)
# example for iqa score inference
# Tensor inputs, img_tensor_x/y: (N, 3, H, W), RGB, 0 ~ 1
score_fr = iqa_metric(img_tensor_x, img_tensor_y)
score_nr = iqa_metric(img_tensor_x)
# img path as inputs.
score_fr = iqa_metric('./ResultsCalibra/dist_dir/I03.bmp', './ResultsCalibra/ref_dir/I03.bmp')
Expand Down
1 change: 1 addition & 0 deletions pyiqa/models/inference_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ def forward(self, target, ref=None, **kwargs):
ref = ref.unsqueeze(0)

if self.metric_mode == 'FR':
assert ref is not None, 'Please specify reference image for Full Reference metric'
output = self.net(target.to(self.device), ref.to(self.device), **kwargs)
elif self.metric_mode == 'NR':
output = self.net(target.to(self.device), **kwargs)
Expand Down

0 comments on commit 8a81159

Please sign in to comment.