In [1]:
import numpy as np
import torch
import torch.nn as nn
# NOTE(review): the PSNR/SSIM classes also call peak_signal_noise_ratio and
# structural_similarity, which no visible cell imports (presumably from
# skimage.metrics) - confirm and import them here.

In [2]:
# Dummy video batch laid out as (batch, channels, frames, height, width).
N, C, F, H, W = (50, 1, 5, 64, 64)
vid_sample = torch.randn((N, C, F, H, W))

In [4]:
# Unit stride with a 3x3x3 kernel and padding 1: frames, height and width
# come out unchanged; only the channel count goes from 1 to 5.
conv3d = nn.Conv3d(
    in_channels=1,
    out_channels=5,
    kernel_size=3,
    stride=1,
    padding=1,
)

conv3d(vid_sample).size()

torch.Size([50, 5, 5, 64, 64])

In [3]:
# Stride 2 on the two spatial axes only: height and width are halved while
# the number of frames is kept (stride 1 along the frame axis).
conv3d = nn.Conv3d(
    in_channels=1,
    out_channels=5,
    kernel_size=3,
    stride=(1, 2, 2),
    padding=1,
)

conv3d(vid_sample).size()

torch.Size([50, 5, 5, 32, 32])

In [6]:
# Transposed convolution with unit stride: frames, height and width are
# preserved, mirroring the unit-stride Conv3d case.
conv3dtranspose = nn.ConvTranspose3d(
    in_channels=1,
    out_channels=5,
    kernel_size=3,
    stride=1,
    padding=1,
)

conv3dtranspose(vid_sample).size()

torch.Size([50, 5, 5, 64, 64])

In [5]:
# Transposed convolution used for upsampling: stride 2 on height and width
# plus output_padding of 1 on those axes doubles them exactly
# ((64 - 1) * 2 - 2 * 1 + 3 + 1 = 128); the frame axis is left untouched.
conv3dtranspose = nn.ConvTranspose3d(
    in_channels=1,
    out_channels=5,
    kernel_size=3,
    stride=(1, 2, 2),
    padding=1,
    output_padding=(0, 1, 1),
)

conv3dtranspose(vid_sample).size()

torch.Size([50, 5, 5, 128, 128])

In [7]:
class PSNR(nn.Module):
    """Mean peak signal-to-noise ratio over every frame of a video batch.

    Inputs are laid out as (N, C, F, H, W): batch, channels, frames, height,
    width. One PSNR value is computed per (video, frame) pair on the
    (C, H, W) slice of that frame, and the mean of all values is returned.

    NOTE(review): relies on ``peak_signal_noise_ratio`` being in scope;
    presumably ``skimage.metrics.peak_signal_noise_ratio`` - confirm.

    Args:
        data_range: Distance between the minimum and maximum possible pixel
            values, forwarded to ``peak_signal_noise_ratio``. Defaults to
            1.0, the value that was previously hard-coded.
    """

    def __init__(self, data_range: float = 1.0) -> None:
        super().__init__()
        self.data_range = data_range

    def forward(self, pred_frames, target_frames):
        """Return the mean per-frame PSNR between predictions and targets.

        Args:
            pred_frames: Tensor of shape (N, C, F, H, W).
            target_frames: Tensor with the same layout as ``pred_frames``.

        Returns:
            A NumPy float64 scalar: the mean of the N * F per-frame values.

        Raises:
            ValueError: If ``pred_frames`` is not 5-dimensional.
        """
        # Unpacking five names keeps the original "must be 5-D" contract.
        N, _, F, _, _ = pred_frames.shape

        # detach() + cpu() so tensors that require grad or live on an
        # accelerator can be converted; a bare .numpy() raises on both.
        pred_np = pred_frames.detach().cpu().numpy()
        target_np = target_frames.detach().cpu().numpy()

        out = np.zeros((N, F))
        for vid_idx in range(N):
            for frame_idx in range(F):
                # Frames live on axis 2. Indexing [vid_idx, frame_idx] would
                # select a *channel*, not a frame.
                out[vid_idx, frame_idx] = peak_signal_noise_ratio(
                    pred_np[vid_idx, :, frame_idx],
                    target_np[vid_idx, :, frame_idx],
                    data_range=self.data_range,
                )
        return out.mean()

class SSIM(nn.Module):
    """Mean structural similarity over every frame of a video batch.

    Inputs are laid out as (N, C, F, H, W): batch, channels, frames, height,
    width. SSIM is evaluated on each 2-D (H, W) image - one per video,
    channel and frame - and the mean of all values is returned, so every
    channel of every frame contributes equally.

    NOTE(review): relies on ``structural_similarity`` being in scope;
    presumably ``skimage.metrics.structural_similarity`` - confirm.

    Args:
        data_range: Distance between the minimum and maximum possible pixel
            values, forwarded to ``structural_similarity``. Defaults to 1.0,
            the value that was previously hard-coded.
    """

    def __init__(self, data_range: float = 1.0) -> None:
        super().__init__()
        self.data_range = data_range

    def forward(self, pred_frames, target_frames):
        """Return the mean per-frame SSIM between predictions and targets.

        Args:
            pred_frames: Tensor of shape (N, C, F, H, W).
            target_frames: Tensor with the same layout as ``pred_frames``.

        Returns:
            A NumPy float64 scalar: the mean of the N * C * F per-image
            values.

        Raises:
            ValueError: If ``pred_frames`` is not 5-dimensional.
        """
        # Unpacking five names keeps the original "must be 5-D" contract.
        N, C, F, _, _ = pred_frames.shape

        # detach() + cpu() so tensors that require grad or live on an
        # accelerator can be converted; a bare .numpy() raises on both.
        pred_np = pred_frames.detach().cpu().numpy()
        target_np = target_frames.detach().cpu().numpy()

        out = np.zeros((N, C, F))
        for vid_idx in range(N):
            for frame_idx in range(F):
                # Frames live on axis 2. Indexing [vid_idx, frame_idx] would
                # select a *channel*, not a frame.
                for ch_idx in range(C):
                    # Compare plain 2-D images so the channel axis is never
                    # mistaken for a spatial one.
                    out[vid_idx, ch_idx, frame_idx] = structural_similarity(
                        pred_np[vid_idx, ch_idx, frame_idx],
                        target_np[vid_idx, ch_idx, frame_idx],
                        data_range=self.data_range,
                    )
        return out.mean()