In [1]:
import torch
import pytorch_model_summary

from models.msire import MultiScaleIREncoder, DilatedWindowedIREncoder, ExpandedWindowedIREncoder
from test import test
from utils.dataloader import get_dataloader

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Multi-scale IREncoder, d = 152, window = 16, layer = 3*6=18, head = 4
device = torch.device('cuda')
multiirencoder = MultiScaleIREncoder(img_res=64, d_embed=152, n_layer=[3,3,3,3,3,3], n_head=4,
                                     hidden_dim_rate=3, window_size=16, sr_upscale=2, test_version=True).to(device)

print(pytorch_model_summary.summary(multiirencoder,
                                    torch.zeros(1, 3, 64, 64, device=device),
                                    show_input=False))

print('GFlops :', multiirencoder.flops(64, 64) / 1000000000)

------------------------------------------------------------------------
      Layer (type)         Output Shape         Param #     Tr. Param #
          Conv2d-1     [1, 152, 64, 64]           4,256           4,256
       LayerNorm-2     [1, 64, 64, 152]             304             304
          Linear-3     [1, 64, 64, 152]          23,256          23,256
    DecoderLayer-4     [1, 64, 64, 152]         478,588         478,588
    DecoderLayer-5     [1, 64, 64, 152]         478,588         478,588
    DecoderLayer-6     [1, 64, 64, 152]         478,588         478,588
          Conv2d-7     [1, 456, 64, 64]          69,768          69,768
            GELU-8     [1, 456, 64, 64]               0               0
          Conv2d-9     [1, 456, 64, 64]           4,560           4,560
         Conv2d-10     [1, 152, 64, 64]          69,464          69,464
      LayerNorm-11     [1, 64, 64, 152]             304             304
         Linear-12     [1, 64, 64, 152]          23,256        

In [17]:
# Multi-scale IREncoder, d = 124, window = 12, layer = 4*8=32, head = 4
device = torch.device('cuda')
multiirencoder = MultiScaleIREncoder(img_res=72, d_embed=124, n_layer=[4,4,4,4,4,4,4,4], n_head=4,
                                     hidden_dim_rate=3, window_size=12, sr_upscale=2, test_version=True).to(device)

print(pytorch_model_summary.summary(multiirencoder,
                                    torch.zeros(1, 3, 72, 72, device=device),
                                    show_input=False))

print('GFlops :', multiirencoder.flops(64, 64) / 1000000000)

------------------------------------------------------------------------
      Layer (type)         Output Shape         Param #     Tr. Param #
          Conv2d-1     [1, 124, 72, 72]           3,472           3,472
       LayerNorm-2     [1, 72, 72, 124]             248             248
          Linear-3     [1, 72, 72, 124]          15,500          15,500
    DecoderLayer-4     [1, 72, 72, 124]         318,752         318,752
    DecoderLayer-5     [1, 72, 72, 124]         318,752         318,752
    DecoderLayer-6     [1, 72, 72, 124]         318,752         318,752
    DecoderLayer-7     [1, 72, 72, 124]         318,752         318,752
          Conv2d-8     [1, 372, 72, 72]          46,500          46,500
            GELU-9     [1, 372, 72, 72]               0               0
         Conv2d-10     [1, 372, 72, 72]           3,720           3,720
         Conv2d-11     [1, 124, 72, 72]          46,252          46,252
      LayerNorm-12     [1, 72, 72, 124]             248        

In [4]:
# Multi-scale IREncoder, d = 124, window = 12, layer = 4*8=32, head = 4
device = torch.device('cuda')
multiirencoder = MultiScaleIREncoder(img_res=56, d_embed=136, n_layer=[4,4,4,4,4,4], n_head=4,
                                     hidden_dim_rate=3, window_size=14, sr_upscale=2, test_version=True).to(device)

print(pytorch_model_summary.summary(multiirencoder,
                                    torch.zeros(1, 3, 56, 56, device=device),
                                    show_input=False))

print('GFlops :', multiirencoder.flops(64, 64) / 1000000000)

------------------------------------------------------------------------
      Layer (type)         Output Shape         Param #     Tr. Param #
          Conv2d-1     [1, 136, 56, 56]           3,808           3,808
       LayerNorm-2     [1, 56, 56, 136]             272             272
          Linear-3     [1, 56, 56, 136]          18,632          18,632
    DecoderLayer-4     [1, 56, 56, 136]         384,332         384,332
    DecoderLayer-5     [1, 56, 56, 136]         384,332         384,332
    DecoderLayer-6     [1, 56, 56, 136]         384,332         384,332
    DecoderLayer-7     [1, 56, 56, 136]         384,332         384,332
          Conv2d-8     [1, 408, 56, 56]          55,896          55,896
            GELU-9     [1, 408, 56, 56]               0               0
         Conv2d-10     [1, 408, 56, 56]           4,080           4,080
         Conv2d-11     [1, 136, 56, 56]          55,624          55,624
      LayerNorm-12     [1, 56, 56, 136]             272        

In [3]:
# Lightweight Multi-scale IREncoder, d = 46, window = 10, layer = 3*4=12, head = 2
device = torch.device('cuda')
multiirencoder = MultiScaleIREncoder(img_res=60, d_embed=46, n_layer=[3,3,3,3], n_head=2,
                                     hidden_dim_rate=3, window_size=10, sr_upscale=2, test_version=True).to(device)

print(pytorch_model_summary.summary(multiirencoder,
                                    torch.zeros(1, 3, 60, 60, device=device),
                                    show_input=False))

print('GFlops :', multiirencoder.flops(640, 360) / 1000000000)

------------------------------------------------------------------------
      Layer (type)         Output Shape         Param #     Tr. Param #
          Conv2d-1      [1, 46, 60, 60]           1,288           1,288
       LayerNorm-2      [1, 60, 60, 46]              92              92
          Linear-3      [1, 60, 60, 46]           2,162           2,162
    DecoderLayer-4      [1, 60, 60, 46]          53,096          53,096
    DecoderLayer-5      [1, 60, 60, 46]          53,096          53,096
    DecoderLayer-6      [1, 60, 60, 46]          53,096          53,096
          Conv2d-7     [1, 138, 60, 60]           6,486           6,486
            GELU-8     [1, 138, 60, 60]               0               0
          Conv2d-9     [1, 138, 60, 60]           1,380           1,380
         Conv2d-10      [1, 46, 60, 60]           6,394           6,394
      LayerNorm-11      [1, 60, 60, 46]              92              92
         Linear-12      [1, 60, 60, 46]           2,162        

In [11]:
# Lightweight Multi-scale IREncoder, d = 46, window = 8, layer = 4*4=16, head = 2
device = torch.device('cuda')
multiirencoder = MultiScaleIREncoder(img_res=48, d_embed=46, n_layer=[4,4,4,4], n_head=2,
                                     hidden_dim_rate=2, window_size=8, sr_upscale=2, test_version=True).to(device)

print(pytorch_model_summary.summary(multiirencoder,
                                    torch.zeros(1, 3, 48, 48, device=device),
                                    show_input=False))

print('GFlops :', multiirencoder.flops(640, 360) / 1000000000)

-----------------------------------------------------------------------
      Layer (type)        Output Shape         Param #     Tr. Param #
          Conv2d-1     [1, 46, 48, 48]           1,288           1,288
       LayerNorm-2     [1, 48, 48, 46]              92              92
          Linear-3     [1, 48, 48, 46]           2,162           2,162
    DecoderLayer-4     [1, 48, 48, 46]          45,910          45,910
    DecoderLayer-5     [1, 48, 48, 46]          45,910          45,910
    DecoderLayer-6     [1, 48, 48, 46]          45,910          45,910
    DecoderLayer-7     [1, 48, 48, 46]          45,910          45,910
          Conv2d-8     [1, 92, 48, 48]           4,324           4,324
            GELU-9     [1, 92, 48, 48]               0               0
         Conv2d-10     [1, 92, 48, 48]             920             920
         Conv2d-11     [1, 46, 48, 48]           4,278           4,278
      LayerNorm-12     [1, 48, 48, 46]              92              92
     

In [7]:
# Dilated-windowed IREncoder, d = 128, window = 16, layer = 6*6=36, head = 8
device = torch.device('cuda')
irencoder = DilatedWindowedIREncoder(img_res=64, d_embed=128, n_layer=[6,6,6,6,6,6], n_head=8,
                                     hidden_dim_rate=2, conv_hidden_rate=2, window_size=16, sr_upscale=2).to(device)

print(pytorch_model_summary.summary(irencoder,
                                    torch.zeros(1, 3, 64, 64, device=device),
                                    show_input=False))

print('GFlops :', irencoder.flops(64, 64) / 1000000000)

------------------------------------------------------------------------------
            Layer (type)         Output Shape         Param #     Tr. Param #
                Conv2d-1     [1, 128, 64, 64]           3,584           3,584
             LayerNorm-2     [1, 64, 64, 128]             256             256
                Linear-3     [1, 64, 64, 128]          16,512          16,512
    TransformerEncoder-4     [1, 64, 64, 128]       1,310,256       1,310,256
                Conv2d-5     [1, 512, 64, 64]          66,048          66,048
                  GELU-6     [1, 512, 64, 64]               0               0
                Conv2d-7     [1, 512, 64, 64]           5,120           5,120
                Conv2d-8     [1, 128, 64, 64]          65,664          65,664
             LayerNorm-9     [1, 64, 64, 128]             256             256
               Linear-10     [1, 64, 64, 128]          16,512          16,512
   TransformerEncoder-11     [1, 64, 64, 128]       1,310,256  

In [13]:
# Lightweight Dilated-windowed IREncoder, d = 32, window = 12, layer = 4*6=24, head = 4
device = torch.device('cuda')
irencoder = DilatedWindowedIREncoder(img_res=64, d_embed=42, n_layer=[6,6,6,6], n_head=6,
                                     hidden_dim_rate=2, conv_hidden_rate=2, window_size=8, sr_upscale=2).to(device)

print(pytorch_model_summary.summary(irencoder,
                                    torch.zeros(1, 3, 64, 64, device=device),
                                    show_input=False))

print('GFlops :', irencoder.flops(640, 360) / 1000000000)

------------------------------------------------------------------------------
            Layer (type)         Output Shape         Param #     Tr. Param #
                Conv2d-1      [1, 42, 64, 64]           1,176           1,176
             LayerNorm-2      [1, 64, 64, 42]              84              84
                Linear-3      [1, 64, 64, 42]           1,806           1,806
    TransformerEncoder-4      [1, 64, 64, 42]         162,828         162,828
                Conv2d-5     [1, 168, 64, 64]           7,224           7,224
                  GELU-6     [1, 168, 64, 64]               0               0
                Conv2d-7     [1, 168, 64, 64]           1,680           1,680
                Conv2d-8      [1, 42, 64, 64]           7,098           7,098
             LayerNorm-9      [1, 64, 64, 42]              84              84
               Linear-10      [1, 64, 64, 42]           1,806           1,806
   TransformerEncoder-11      [1, 64, 64, 42]         162,828  

In [2]:
# Expanded-windowed IREncoder, d = 144, window = 16, layer = 6*5=30, head = 8
device = torch.device('cuda')
irencoder = ExpandedWindowedIREncoder(img_res=64, d_embed=144, n_layer=[6,6,6,6,6], n_head=8,
                                      hidden_dim_rate=2, conv_hidden_rate=2, window_size=16, sr_upscale=2).to(device)

print(pytorch_model_summary.summary(irencoder,
                                    torch.zeros(1, 3, 64, 64, device=device),
                                    show_input=False))

print('GFlops :', irencoder.flops(64, 64) / 1000000000)

------------------------------------------------------------------------------
            Layer (type)         Output Shape         Param #     Tr. Param #
                Conv2d-1     [1, 144, 64, 64]           4,032           4,032
             LayerNorm-2     [1, 64, 64, 144]             288             288
                Linear-3     [1, 64, 64, 144]          20,880          20,880
    TransformerEncoder-4     [1, 64, 64, 144]       1,737,984       1,737,984
                Conv2d-5     [1, 576, 64, 64]          83,520          83,520
                  GELU-6     [1, 576, 64, 64]               0               0
                Conv2d-7     [1, 576, 64, 64]           5,760           5,760
                Conv2d-8     [1, 144, 64, 64]          83,088          83,088
             LayerNorm-9     [1, 64, 64, 144]             288             288
               Linear-10     [1, 64, 64, 144]          20,880          20,880
   TransformerEncoder-11     [1, 64, 64, 144]       1,737,984  

In [2]:
# Expanded-windowed IREncoder, d = 160, window = 16, layer = 6*6=36, head = 8
device = torch.device('cuda')
irencoder = ExpandedWindowedIREncoder(img_res=64, d_embed=160, n_layer=[6,6,6,6,6,6], n_head=8,
                                      hidden_dim_rate=2, conv_hidden_rate=1.2, window_size=16, sr_upscale=2).to(device)

print(pytorch_model_summary.summary(irencoder,
                                    torch.zeros(1, 3, 64, 64, device=device),
                                    show_input=False))

print('GFlops :', irencoder.flops(64, 64) / 1000000000)

------------------------------------------------------------------------------
            Layer (type)         Output Shape         Param #     Tr. Param #
                Conv2d-1     [1, 160, 64, 64]           4,480           4,480
             LayerNorm-2     [1, 64, 64, 160]             320             320
                Linear-3     [1, 64, 64, 160]          25,760          25,760
    TransformerEncoder-4     [1, 64, 64, 160]       1,559,616       1,559,616
                Conv2d-5     [1, 640, 64, 64]         103,040         103,040
                  GELU-6     [1, 640, 64, 64]               0               0
                Conv2d-7     [1, 640, 64, 64]           6,400           6,400
                Conv2d-8     [1, 160, 64, 64]         102,560         102,560
             LayerNorm-9     [1, 64, 64, 160]             320             320
               Linear-10     [1, 64, 64, 160]          25,760          25,760
   TransformerEncoder-11     [1, 64, 64, 160]       1,559,616  

In [8]:
# Expanded-windowed IREncoder, d = 144, window = 12, layer = 6*6=36, head = 8
device = torch.device('cuda')
irencoder = ExpandedWindowedIREncoder(img_res=64, d_embed=[144,144,168,168,192,192], n_layer=[6,6,6,6,6,6], n_head=6,
                                      hidden_dim_rate=2, conv_hidden_rate=1, window_size=16, sr_upscale=2).to(device)

print(pytorch_model_summary.summary(irencoder,
                                    torch.zeros(1, 3, 64, 64, device=device),
                                    show_input=False))

print('GFlops :', irencoder.flops(64, 64) / 1000000000)

------------------------------------------------------------------------------
            Layer (type)         Output Shape         Param #     Tr. Param #
                Conv2d-1     [1, 192, 64, 64]           5,376           5,376
             LayerNorm-2     [1, 64, 64, 192]             384             384
                Linear-3     [1, 64, 64, 144]          27,792          27,792
    TransformerEncoder-4     [1, 64, 64, 144]       1,155,312       1,155,312
                Conv2d-5     [1, 576, 64, 64]          83,520          83,520
                  GELU-6     [1, 576, 64, 64]               0               0
                Conv2d-7     [1, 576, 64, 64]           5,760           5,760
                Conv2d-8     [1, 192, 64, 64]         110,784         110,784
             LayerNorm-9     [1, 64, 64, 192]             384             384
               Linear-10     [1, 64, 64, 144]          27,792          27,792
   TransformerEncoder-11     [1, 64, 64, 144]       1,155,312  

In [3]:
# Lightweight Expanded-windowed IREncoder, d = 42, window = 8, layer = 6*4=24, head = 4
device = torch.device('cuda')
irencoder = ExpandedWindowedIREncoder(img_res=64, d_embed=42, n_layer=[6,6,6,6], n_head=4,
                                      hidden_dim_rate=2, conv_hidden_rate=2, window_size=8, sr_upscale=2).to(device)

print(pytorch_model_summary.summary(irencoder,
                                    torch.zeros(1, 3, 64, 64, device=device),
                                    show_input=False))

print('GFlops :', irencoder.flops(640, 360) / 1000000000)

------------------------------------------------------------------------------
            Layer (type)         Output Shape         Param #     Tr. Param #
                Conv2d-1      [1, 42, 64, 64]           1,176           1,176
             LayerNorm-2      [1, 64, 64, 42]              84              84
                Linear-3      [1, 64, 64, 42]           1,806           1,806
    TransformerEncoder-4      [1, 64, 64, 42]         180,456         180,456
                Conv2d-5     [1, 168, 64, 64]           7,224           7,224
                  GELU-6     [1, 168, 64, 64]               0               0
                Conv2d-7     [1, 168, 64, 64]           1,680           1,680
                Conv2d-8      [1, 42, 64, 64]           7,098           7,098
             LayerNorm-9      [1, 64, 64, 42]              84              84
               Linear-10      [1, 64, 64, 42]           1,806           1,806
   TransformerEncoder-11      [1, 64, 64, 42]         180,456  

In [13]:
# Lightweight Expanded-windowed IREncoder, d = 42, window = 8, layer = 6*4=24, head = 4
device = torch.device('cuda')
irencoder = ExpandedWindowedIREncoder(img_res=60, d_embed=56, n_layer=[6,6,6], n_head=4,
                                      hidden_dim_rate=2, conv_hidden_rate=1, window_size=12, sr_upscale=2).to(device)

print(pytorch_model_summary.summary(irencoder,
                                    torch.zeros(1, 3, 60, 60, device=device),
                                    show_input=False))

print('GFlops :', irencoder.flops(640, 360) / 1000000000)

------------------------------------------------------------------------------
            Layer (type)         Output Shape         Param #     Tr. Param #
                Conv2d-1      [1, 56, 60, 60]           1,568           1,568
             LayerNorm-2      [1, 60, 60, 56]             112             112
                Linear-3      [1, 60, 60, 56]           3,192           3,192
    TransformerEncoder-4      [1, 60, 60, 56]         210,864         210,864
                Conv2d-5     [1, 224, 60, 60]          12,768          12,768
                  GELU-6     [1, 224, 60, 60]               0               0
                Conv2d-7     [1, 224, 60, 60]           2,240           2,240
                Conv2d-8      [1, 56, 60, 60]          12,600          12,600
             LayerNorm-9      [1, 60, 60, 56]             112             112
               Linear-10      [1, 60, 60, 56]           3,192           3,192
   TransformerEncoder-11      [1, 60, 60, 56]         210,864  

In [8]:
# Lightweight Expanded-windowed IREncoder, d = 42, window = 8, layer = 6*4=24, head = 4
device = torch.device('cuda')
irencoder = ExpandedWindowedIREncoder(img_res=60, d_embed=48, n_layer=[6,6,6,6], n_head=3,
                                      hidden_dim_rate=2, conv_hidden_rate=1, window_size=12, sr_upscale=2).to(device)

print(pytorch_model_summary.summary(irencoder,
                                    torch.zeros(1, 3, 60, 60, device=device),
                                    show_input=False))

print('GFlops :', irencoder.flops(640, 360) / 1000000000)

------------------------------------------------------------------------------
            Layer (type)         Output Shape         Param #     Tr. Param #
                Conv2d-1      [1, 48, 60, 60]           1,344           1,344
             LayerNorm-2      [1, 60, 60, 48]              96              96
                Linear-3      [1, 60, 60, 48]           2,352           2,352
    TransformerEncoder-4      [1, 60, 60, 48]         159,336         159,336
                Conv2d-5     [1, 192, 60, 60]           9,408           9,408
                  GELU-6     [1, 192, 60, 60]               0               0
                Conv2d-7     [1, 192, 60, 60]           1,920           1,920
                Conv2d-8      [1, 48, 60, 60]           9,264           9,264
             LayerNorm-9      [1, 60, 60, 48]              96              96
               Linear-10      [1, 60, 60, 48]           2,352           2,352
   TransformerEncoder-11      [1, 60, 60, 48]         159,336  

In [3]:
# Lightweight Expanded-windowed IREncoder, d = 42, window = 8, layer = 6*4=24, head = 4
device = torch.device('cuda')
irencoder = ExpandedWindowedIREncoder(img_res=60, d_embed=[36,48,48,60], n_layer=[6,6,6,6], n_head=3, hidden_dim_rate=2,
                                      conv_hidden_rate=1, residual_hidden_rate=2, window_size=12, sr_upscale=2).to(device)

print(pytorch_model_summary.summary(irencoder,
                                    torch.zeros(1, 3, 60, 60, device=device),
                                    show_input=False))

print('GFlops :', irencoder.flops(640, 360) / 1000000000)

------------------------------------------------------------------------------
            Layer (type)         Output Shape         Param #     Tr. Param #
                Conv2d-1      [1, 60, 60, 60]           1,680           1,680
             LayerNorm-2      [1, 60, 60, 60]             120             120
                Linear-3      [1, 60, 60, 36]           2,196           2,196
    TransformerEncoder-4      [1, 60, 60, 36]         103,968         103,968
                Conv2d-5      [1, 72, 60, 60]           2,664           2,664
                  GELU-6      [1, 72, 60, 60]               0               0
                Conv2d-7      [1, 72, 60, 60]             720             720
                Conv2d-8      [1, 60, 60, 60]           4,380           4,380
             LayerNorm-9      [1, 60, 60, 60]             120             120
               Linear-10      [1, 60, 60, 48]           2,928           2,928
   TransformerEncoder-11      [1, 60, 60, 48]         159,336  

In [2]:
# Parameters adopting from old model
model = MultiScaleIREncoder(img_res=48, d_embed=46, n_layer=[4,4,4,4], n_head=2,
                            hidden_dim_rate=2, window_size=8, sr_upscale=2)
state_dict = torch.load('logs/X2/20230723_062415/state_dict/state_dict_epoch_1300.pt', map_location='cpu')
model.load_state_dict(state_dict, strict=False)
sub_dict = {'weight' : state_dict['initial_feature_mappings.0.weight'],
            'bias' : state_dict['initial_feature_mappings.0.bias']}
model.initial_feature_mapping.load_state_dict(sub_dict)

RuntimeError: Error(s) in loading state_dict for MultiScaleIREncoder:
	size mismatch for feature_addition_modules.0.0.weight: copying a param with shape torch.Size([138, 46, 1, 1]) from checkpoint, the shape in current model is torch.Size([92, 46, 1, 1]).
	size mismatch for feature_addition_modules.0.0.bias: copying a param with shape torch.Size([138]) from checkpoint, the shape in current model is torch.Size([92]).
	size mismatch for feature_addition_modules.0.2.weight: copying a param with shape torch.Size([138, 1, 3, 3]) from checkpoint, the shape in current model is torch.Size([92, 1, 3, 3]).
	size mismatch for feature_addition_modules.0.2.bias: copying a param with shape torch.Size([138]) from checkpoint, the shape in current model is torch.Size([92]).
	size mismatch for feature_addition_modules.0.4.weight: copying a param with shape torch.Size([46, 138, 1, 1]) from checkpoint, the shape in current model is torch.Size([46, 92, 1, 1]).
	size mismatch for feature_addition_modules.1.0.weight: copying a param with shape torch.Size([138, 46, 1, 1]) from checkpoint, the shape in current model is torch.Size([92, 46, 1, 1]).
	size mismatch for feature_addition_modules.1.0.bias: copying a param with shape torch.Size([138]) from checkpoint, the shape in current model is torch.Size([92]).
	size mismatch for feature_addition_modules.1.2.weight: copying a param with shape torch.Size([138, 1, 3, 3]) from checkpoint, the shape in current model is torch.Size([92, 1, 3, 3]).
	size mismatch for feature_addition_modules.1.2.bias: copying a param with shape torch.Size([138]) from checkpoint, the shape in current model is torch.Size([92]).
	size mismatch for feature_addition_modules.1.4.weight: copying a param with shape torch.Size([46, 138, 1, 1]) from checkpoint, the shape in current model is torch.Size([46, 92, 1, 1]).
	size mismatch for feature_addition_modules.2.0.weight: copying a param with shape torch.Size([138, 46, 1, 1]) from checkpoint, the shape in current model is torch.Size([92, 46, 1, 1]).
	size mismatch for feature_addition_modules.2.0.bias: copying a param with shape torch.Size([138]) from checkpoint, the shape in current model is torch.Size([92]).
	size mismatch for feature_addition_modules.2.2.weight: copying a param with shape torch.Size([138, 1, 3, 3]) from checkpoint, the shape in current model is torch.Size([92, 1, 3, 3]).
	size mismatch for feature_addition_modules.2.2.bias: copying a param with shape torch.Size([138]) from checkpoint, the shape in current model is torch.Size([92]).
	size mismatch for feature_addition_modules.2.4.weight: copying a param with shape torch.Size([46, 138, 1, 1]) from checkpoint, the shape in current model is torch.Size([46, 92, 1, 1]).
	size mismatch for feature_addition_modules.3.0.weight: copying a param with shape torch.Size([138, 46, 1, 1]) from checkpoint, the shape in current model is torch.Size([92, 46, 1, 1]).
	size mismatch for feature_addition_modules.3.0.bias: copying a param with shape torch.Size([138]) from checkpoint, the shape in current model is torch.Size([92]).
	size mismatch for feature_addition_modules.3.2.weight: copying a param with shape torch.Size([138, 1, 3, 3]) from checkpoint, the shape in current model is torch.Size([92, 1, 3, 3]).
	size mismatch for feature_addition_modules.3.2.bias: copying a param with shape torch.Size([138]) from checkpoint, the shape in current model is torch.Size([92]).
	size mismatch for feature_addition_modules.3.4.weight: copying a param with shape torch.Size([46, 138, 1, 1]) from checkpoint, the shape in current model is torch.Size([46, 92, 1, 1]).

In [3]:
model.eval()
dataloader = get_dataloader(setting='test', dataset='Set5')
psnr, ssim = test(model, dataloader, None, len(dataloader), 2, torch.device('cpu'))
print(psnr)
print(ssim)

38.13659591258586
0.9611594567132649


In [6]:
torch.save(model, 'logs/X2/20230720_045925/model/model.pt')
torch.save(model.state_dict(), 'logs/X2/20230720_045925/state_dict/state_dict_epoch_1540.pt')