Skip to content

Commit

Permalink
Remove ImageNet1k mean values and organize losses into separate files
Browse files Browse the repository at this point in the history
  • Loading branch information
muslll committed Apr 7, 2024
1 parent f23e1ca commit 417432c
Show file tree
Hide file tree
Showing 18 changed files with 490 additions and 455 deletions.
2 changes: 1 addition & 1 deletion neosr/archs/atd_arch.py
Original file line number Diff line number Diff line change
Expand Up @@ -927,7 +927,7 @@ def __init__(
self.no_norm = None

if in_chans == 3:
rgb_mean = (0.4488, 0.4371, 0.4040)
rgb_mean = (0.5, 0.5, 0.5)
self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1)
else:
self.mean = torch.zeros(1, 1, 1, 1)
Expand Down
2 changes: 1 addition & 1 deletion neosr/archs/craft_arch.py
Original file line number Diff line number Diff line change
Expand Up @@ -661,7 +661,7 @@ def __init__(self,
self.num_feat = num_feat
self.num_out_ch = num_out_ch
if in_chans == 3:
rgb_mean = (0.4488, 0.4371, 0.4040)
rgb_mean = (0.5, 0.5, 0.5)
self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1)
else:
self.mean = torch.zeros(1, 1, 1, 1)
Expand Down
2 changes: 1 addition & 1 deletion neosr/archs/dat_arch.py
Original file line number Diff line number Diff line change
Expand Up @@ -751,7 +751,7 @@ def __init__(self,
num_feat = 64
self.img_range = img_range
if in_chans == 3:
rgb_mean = (0.4488, 0.4371, 0.4040)
rgb_mean = (0.5, 0.5, 0.5)
self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1)
else:
self.mean = torch.zeros(1, 1, 1, 1)
Expand Down
2 changes: 1 addition & 1 deletion neosr/archs/hat_arch.py
Original file line number Diff line number Diff line change
Expand Up @@ -740,7 +740,7 @@ def __init__(self,
num_feat = 64
self.img_range = img_range
if in_chans == 3:
rgb_mean = (0.4488, 0.4371, 0.4040)
rgb_mean = (0.5, 0.5, 0.5)
self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1)
else:
self.mean = torch.zeros(1, 1, 1, 1)
Expand Down
2 changes: 1 addition & 1 deletion neosr/archs/hdsrnet_arch.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,7 +335,7 @@ class MeanShift(nn.Conv2d):
def __init__(
self,
rgb_range,
rgb_mean=(0.4488, 0.4371, 0.4040),
rgb_mean=(0.5, 0.5, 0.5),
rgb_std=(1.0, 1.0, 1.0),
sign=-1,
):
Expand Down
2 changes: 1 addition & 1 deletion neosr/archs/rgt_arch.py
Original file line number Diff line number Diff line change
Expand Up @@ -842,7 +842,7 @@ def __init__(
num_feat = 64
self.img_range = img_range
if in_chans == 3:
rgb_mean = (0.4488, 0.4371, 0.4040)
rgb_mean = (0.5, 0.5, 0.5)
self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1)
else:
self.mean = torch.zeros(1, 1, 1, 1)
Expand Down
2 changes: 1 addition & 1 deletion neosr/archs/span_arch.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ def __init__(self,
bias=True,
norm=False,
img_range=1.0,
rgb_mean=(0.4488, 0.4371, 0.4040)
rgb_mean=(0.5, 0.5, 0.5)
):
super(span, self).__init__()

Expand Down
2 changes: 1 addition & 1 deletion neosr/archs/swinir_arch.py
Original file line number Diff line number Diff line change
Expand Up @@ -774,7 +774,7 @@ def __init__(self,
num_feat = 64
self.img_range = img_range
if in_chans == 3:
rgb_mean = (0.4488, 0.4371, 0.4040)
rgb_mean = (0.5, 0.5, 0.5)
self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1)
else:
self.mean = torch.zeros(1, 1, 1, 1)
Expand Down
Loading

0 comments on commit 417432c

Please sign in to comment.