In [25]:
import os
from models import *
from data import testsets
import torch
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms

from data.datasets import *
from trainers.distiller import Distiller
from torch.utils.data import DataLoader
import models
import losses
import datetime
from os.path import join

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
torch.cuda.get_device_name()

'NVIDIA GeForce RTX 3090'

In [26]:
class testArgs:
    gpu_id = 0
    net = 'STMFNet'
    dataset = 'Ucf101_quintuplet'
    metrics = ['PSNR', 'SSIM']
    checkpoint = './models/stmfnet.pth'
    data_dir = 'D:/stmfnet_data'
    out_dir = './tests'
    featc = [64, 128, 256, 512]
    featnet = 'UMultiScaleResNext'
    featnorm = 'batch'
    kernel_size = 5
    dilation = 1
    finetune_pwc = False

class trainArgs:
    gpu_id = 0
    net = 'STMFNet'
    data_dir = 'D:/stmfnet_data'
    out_dir = './train_results'
    load = None
    epochs = 70
    batch_size = 2
    loss = "1*Lap"
    patch_size = 256
    lr = 0.001
    lr_decay = 20
    decay_type = 'step'
    gamma = 0.5
    patience = None
    optimizer = 'ADAMax'
    weight_decay = 0
    featc = [64, 128, 256, 512]
    featnet = 'UMultiScaleResNext'
    featnorm = 'batch'
    kernel_size = 5
    dilation = 1
    finetune_pwc = False

args=trainArgs()

### import data

In [27]:
torch.cuda.set_device(args.gpu_id)

# training sets
vimeo90k_train = Vimeo90k_quintuplet(
    join(args.data_dir, "vimeo_septuplet"),
    train=True,
    crop_sz=(args.patch_size, args.patch_size),
)
bvidvc_train = BVIDVC_quintuplet(
    join(args.data_dir, "bvidvc"), crop_sz=(args.patch_size, args.patch_size)
)

# validation set
vimeo90k_valid = Vimeo90k_quintuplet(
    join(args.data_dir, "vimeo_septuplet"),
    train=False,
    crop_sz=(args.patch_size, args.patch_size),
    augment_s=False,
    augment_t=False,
)

datasets_train = [vimeo90k_train, bvidvc_train]
train_sampler = Sampler(datasets_train, iter=True)

# data loaders
train_loader = DataLoader(
    dataset=train_sampler, batch_size=args.batch_size, shuffle=True, num_workers=0
)
valid_loader = DataLoader(
    dataset=vimeo90k_valid, batch_size=args.batch_size, num_workers=0
)

### teacher model

In [28]:
# Load the model

torch.cuda.set_device(args.gpu_id)

if not os.path.exists(args.out_dir):
    os.mkdir(args.out_dir)

def to_device(data, device):
    if isinstance(data, (list,tuple)):
        return [to_device(x, device) for x in data]
    return data.to(device, non_blocking=True)

# def load_model(filepath):

#     checkpoint = torch.load(filepath)
#     model = STMFNet(args).cuda()
#     model.load_state_dict(checkpoint['state_dict'])
    
#     return model

# model = load_model("./models/stmfnet.pth")

teacher = to_device(STMFNet(args), device)
teacher.to(device)
checkpoint = torch.load('./models/stmfnet.pth')
teacher.load_state_dict(checkpoint['state_dict'])

<All keys matched successfully>

### student model

In [29]:
import torch
import torch.nn as nn
from models.misc.resnet_3D import r3d_18, Conv_3d, upConv3D
from models.misc import Identity
import cupy_module.adacof as adacof
from cupy_module.softsplat import ModuleSoftsplat
import sys
from torch.nn import functional as F
from utility import moduleNormalize, gaussian_kernel
from models import feature
from models.misc import MIMOGridNet, Upsampler_8tap
from models.misc import PWCNet
from models.misc.pwcnet import backwarp

class UNet3d_18(nn.Module):
    def __init__(self, channels=[32, 64, 96, 128], bn=True):
        super(UNet3d_18, self).__init__()
        growth = 2  # since concatenating previous outputs
        upmode = "transpose"  # use transposeConv to upsample

        self.channels = channels

        self.lrelu = nn.LeakyReLU(0.2, True)

        self.encoder = r3d_18(bn=bn, channels=channels)

        self.decoder = nn.Sequential(
            Conv_3d(
                channels[::-1][0],
                channels[::-1][1],
                kernel_size=3,
                padding=1,
                bias=True,
            ),
            upConv3D(
                channels[::-1][1] * growth,
                channels[::-1][2],
                kernel_size=(3, 4, 4),
                stride=(1, 2, 2),
                padding=(1, 1, 1),
                upmode=upmode,
            ),
            upConv3D(
                channels[::-1][2] * growth,
                channels[::-1][3],
                kernel_size=(3, 4, 4),
                stride=(1, 2, 2),
                padding=(1, 1, 1),
                upmode=upmode,
            ),
            Conv_3d(
                channels[::-1][3] * growth,
                channels[::-1][3],
                kernel_size=3,
                padding=1,
                bias=True,
            ),
            upConv3D(
                channels[::-1][3] * growth,
                channels[::-1][3],
                kernel_size=(3, 4, 4),
                stride=(1, 2, 2),
                padding=(1, 1, 1),
                upmode=upmode,
            ),
        )

        self.feature_fuse = nn.Sequential(
            *(
                [
                    nn.Conv2d(
                        channels[::-1][3] * 5,
                        channels[::-1][3],
                        kernel_size=1,
                        stride=1,
                        bias=False,
                    )
                ]
                + [nn.BatchNorm2d(channels[::-1][3]) if bn else Identity]
            )
        )

        self.outconv = nn.Sequential(
            nn.ReflectionPad2d(3),
            nn.Conv2d(channels[::-1][3], 3, kernel_size=7, stride=1, padding=0),
        )

    def forward(self, im1, im3, im5, im7, im4_tilde):
        images = torch.stack((im1, im3, im4_tilde, im5, im7), dim=2)

        x_0, x_1, x_2, x_3, x_4 = self.encoder(images)

        dx_3 = self.lrelu(self.decoder[0](x_4))
        dx_3 = torch.cat([dx_3, x_3], dim=1)

        dx_2 = self.lrelu(self.decoder[1](dx_3))
        dx_2 = torch.cat([dx_2, x_2], dim=1)

        dx_1 = self.lrelu(self.decoder[2](dx_2))
        dx_1 = torch.cat([dx_1, x_1], dim=1)

        dx_0 = self.lrelu(self.decoder[3](dx_1))
        dx_0 = torch.cat([dx_0, x_0], dim=1)

        dx_out = self.lrelu(self.decoder[4](dx_0))
        dx_out = torch.cat(torch.unbind(dx_out, 2), 1)

        out = self.lrelu(self.feature_fuse(dx_out))
        out = self.outconv(out)

        return out


class KernelEstimation(torch.nn.Module):
    def __init__(self, kernel_size):
        super(KernelEstimation, self).__init__()
        self.kernel_size = kernel_size

        def Subnet_offset(ks):
            return torch.nn.Sequential(
                torch.nn.Conv2d(
                    in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1
                ),
                torch.nn.ReLU(inplace=False),
                torch.nn.Conv2d(
                    in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1
                ),
                torch.nn.ReLU(inplace=False),
                torch.nn.Conv2d(
                    in_channels=64, out_channels=ks, kernel_size=3, stride=1, padding=1
                ),
                torch.nn.ReLU(inplace=False),
                torch.nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True),
                torch.nn.Conv2d(
                    in_channels=ks, out_channels=ks, kernel_size=3, stride=1, padding=1
                ),
            )

        def Subnet_weight(ks):
            return torch.nn.Sequential(
                torch.nn.Conv2d(
                    in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1
                ),
                torch.nn.ReLU(inplace=False),
                torch.nn.Conv2d(
                    in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1
                ),
                torch.nn.ReLU(inplace=False),
                torch.nn.Conv2d(
                    in_channels=64, out_channels=ks, kernel_size=3, stride=1, padding=1
                ),
                torch.nn.ReLU(inplace=False),
                torch.nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True),
                torch.nn.Conv2d(
                    in_channels=ks, out_channels=ks, kernel_size=3, stride=1, padding=1
                ),
                torch.nn.Softmax(dim=1),
            )

        def Subnet_offset_ds(ks):
            return torch.nn.Sequential(
                torch.nn.Conv2d(
                    in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1
                ),
                torch.nn.ReLU(inplace=False),
                torch.nn.Conv2d(
                    in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1
                ),
                torch.nn.ReLU(inplace=False),
                torch.nn.Conv2d(
                    in_channels=64, out_channels=ks, kernel_size=3, stride=1, padding=1
                ),
            )

        def Subnet_weight_ds(ks):
            return torch.nn.Sequential(
                torch.nn.Conv2d(
                    in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1
                ),
                torch.nn.ReLU(inplace=False),
                torch.nn.Conv2d(
                    in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1
                ),
                torch.nn.ReLU(inplace=False),
                torch.nn.Conv2d(
                    in_channels=64, out_channels=ks, kernel_size=3, stride=1, padding=1
                ),
                torch.nn.Softmax(dim=1),
            )

        def Subnet_offset_us(ks):
            return torch.nn.Sequential(
                torch.nn.Conv2d(
                    in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1
                ),
                torch.nn.ReLU(inplace=False),
                torch.nn.Conv2d(
                    in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1
                ),
                torch.nn.ReLU(inplace=False),
                torch.nn.Conv2d(
                    in_channels=64, out_channels=ks, kernel_size=3, stride=1, padding=1
                ),
                torch.nn.ReLU(inplace=False),
                torch.nn.Upsample(scale_factor=4, mode="bilinear", align_corners=True),
                torch.nn.Conv2d(
                    in_channels=ks, out_channels=ks, kernel_size=3, stride=1, padding=1
                ),
            )

        def Subnet_weight_us(ks):
            return torch.nn.Sequential(
                torch.nn.Conv2d(
                    in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1
                ),
                torch.nn.ReLU(inplace=False),
                torch.nn.Conv2d(
                    in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1
                ),
                torch.nn.ReLU(inplace=False),
                torch.nn.Conv2d(
                    in_channels=64, out_channels=ks, kernel_size=3, stride=1, padding=1
                ),
                torch.nn.ReLU(inplace=False),
                torch.nn.Upsample(scale_factor=4, mode="bilinear", align_corners=True),
                torch.nn.Conv2d(
                    in_channels=ks, out_channels=ks, kernel_size=3, stride=1, padding=1
                ),
                torch.nn.Softmax(dim=1),
            )

        self.moduleWeight1_ds = Subnet_weight_ds(self.kernel_size**2)
        self.moduleAlpha1_ds = Subnet_offset_ds(self.kernel_size**2)
        self.moduleBeta1_ds = Subnet_offset_ds(self.kernel_size**2)
        self.moduleWeight2_ds = Subnet_weight_ds(self.kernel_size**2)
        self.moduleAlpha2_ds = Subnet_offset_ds(self.kernel_size**2)
        self.moduleBeta2_ds = Subnet_offset_ds(self.kernel_size**2)

        self.moduleWeight1 = Subnet_weight(self.kernel_size**2)
        self.moduleAlpha1 = Subnet_offset(self.kernel_size**2)
        self.moduleBeta1 = Subnet_offset(self.kernel_size**2)
        self.moduleWeight2 = Subnet_weight(self.kernel_size**2)
        self.moduleAlpha2 = Subnet_offset(self.kernel_size**2)
        self.moduleBeta2 = Subnet_offset(self.kernel_size**2)

        self.moduleWeight1_us = Subnet_weight_us(self.kernel_size**2)
        self.moduleAlpha1_us = Subnet_offset_us(self.kernel_size**2)
        self.moduleBeta1_us = Subnet_offset_us(self.kernel_size**2)
        self.moduleWeight2_us = Subnet_weight_us(self.kernel_size**2)
        self.moduleAlpha2_us = Subnet_offset_us(self.kernel_size**2)
        self.moduleBeta2_us = Subnet_offset_us(self.kernel_size**2)

    def forward(self, tensorCombine):
        # Frame 0
        Weight1_ds = self.moduleWeight1_ds(tensorCombine)
        Weight1 = self.moduleWeight1(tensorCombine)
        Weight1_us = self.moduleWeight1_us(tensorCombine)
        Alpha1_ds = self.moduleAlpha1_ds(tensorCombine)
        Alpha1 = self.moduleAlpha1(tensorCombine)
        Alpha1_us = self.moduleAlpha1_us(tensorCombine)
        Beta1_ds = self.moduleBeta1_ds(tensorCombine)
        Beta1 = self.moduleBeta1(tensorCombine)
        Beta1_us = self.moduleBeta1_us(tensorCombine)

        # Frame 2
        Weight2_ds = self.moduleWeight2_ds(tensorCombine)
        Weight2 = self.moduleWeight2(tensorCombine)
        Weight2_us = self.moduleWeight2_us(tensorCombine)
        Alpha2_ds = self.moduleAlpha2_ds(tensorCombine)
        Alpha2 = self.moduleAlpha2(tensorCombine)
        Alpha2_us = self.moduleAlpha2_us(tensorCombine)
        Beta2_ds = self.moduleBeta2_ds(tensorCombine)
        Beta2 = self.moduleBeta2(tensorCombine)
        Beta2_us = self.moduleBeta2_us(tensorCombine)

        return (
            Weight1_ds,
            Alpha1_ds,
            Beta1_ds,
            Weight2_ds,
            Alpha2_ds,
            Beta2_ds,
            Weight1,
            Alpha1,
            Beta1,
            Weight2,
            Alpha2,
            Beta2,
            Weight1_us,
            Alpha1_us,
            Beta1_us,
            Weight2_us,
            Alpha2_us,
            Beta2_us,
        )


class student_STMFNet(torch.nn.Module):
    def __init__(self, args):

        super(student_STMFNet, self).__init__()

        class Metric(torch.nn.Module):
            def __init__(self):
                super(Metric, self).__init__()
                self.paramScale = torch.nn.Parameter(-torch.ones(1, 1, 1, 1))

            def forward(self, tenFirst, tenSecond, tenFlow):
                return self.paramScale * F.l1_loss(
                    input=tenFirst,
                    target=backwarp(tenSecond, tenFlow),
                    reduction="none",
                ).mean(1, True)

        self.args = args
        self.kernel_size = args.kernel_size
        self.kernel_pad = int(((args.kernel_size - 1) * args.dilation) / 2.0)
        self.dilation = args.dilation

        self.feature_extractor = getattr(feature, args.featnet)(
            args.featc, norm_layer=args.featnorm
        )

        self.get_kernel = KernelEstimation(self.kernel_size)

        self.modulePad = torch.nn.ReplicationPad2d(
            [self.kernel_pad, self.kernel_pad, self.kernel_pad, self.kernel_pad]
        )

        self.moduleAdaCoF = adacof.FunctionAdaCoF.apply

        self.gauss_kernel = torch.nn.Parameter(
            gaussian_kernel(5, 0.5).repeat(3, 1, 1, 1), requires_grad=False
        )

        self.upsampler = Upsampler_8tap()

        self.scale_synthesis = MIMOGridNet(
            (6, 6 + 6, 6), (3,), grid_chs=(32, 64, 96), n_row=3, n_col=4, outrow=(1,)
        )

        self.flow_estimator = PWCNet()

        self.softsplat = ModuleSoftsplat(strType="softmax")

        self.metric = Metric()

        self.dyntex_generator = UNet3d_18(bn=args.featnorm)

        # freeze weights of PWCNet if not finetuning it
        if not args.finetune_pwc:
            for param in self.flow_estimator.parameters():
                param.requires_grad = False

    def forward(self, I0, I1, I2, I3, *args):
        h0 = int(list(I1.size())[2])
        w0 = int(list(I1.size())[3])
        h2 = int(list(I2.size())[2])
        w2 = int(list(I2.size())[3])
        if h0 != h2 or w0 != w2:
            sys.exit("Frame sizes do not match")

        h_padded = False
        w_padded = False
        if h0 % 128 != 0:
            pad_h = 128 - (h0 % 128)
            I0 = F.pad(I0, (0, 0, 0, pad_h), mode="reflect")
            I1 = F.pad(I1, (0, 0, 0, pad_h), mode="reflect")
            I2 = F.pad(I2, (0, 0, 0, pad_h), mode="reflect")
            I3 = F.pad(I3, (0, 0, 0, pad_h), mode="reflect")
            h_padded = True

        if w0 % 128 != 0:
            pad_w = 128 - (w0 % 128)
            I0 = F.pad(I0, (0, pad_w, 0, 0), mode="reflect")
            I1 = F.pad(I1, (0, pad_w, 0, 0), mode="reflect")
            I2 = F.pad(I2, (0, pad_w, 0, 0), mode="reflect")
            I3 = F.pad(I3, (0, pad_w, 0, 0), mode="reflect")
            w_padded = True

        feats = self.feature_extractor(moduleNormalize(I1), moduleNormalize(I2))
        kernelest = self.get_kernel(feats)
        Weight1_ds, Alpha1_ds, Beta1_ds, Weight2_ds, Alpha2_ds, Beta2_ds = kernelest[:6]
        Weight1, Alpha1, Beta1, Weight2, Alpha2, Beta2 = kernelest[6:12]
        Weight1_us, Alpha1_us, Beta1_us, Weight2_us, Alpha2_us, Beta2_us = kernelest[
            12:
        ]

        # Original scale
        tensorAdaCoF1 = (
            self.moduleAdaCoF(self.modulePad(I1), Weight1, Alpha1, Beta1, self.dilation)
            * 1.0
        )
        tensorAdaCoF2 = (
            self.moduleAdaCoF(self.modulePad(I2), Weight2, Alpha2, Beta2, self.dilation)
            * 1.0
        )

        # 1/2 downsampled version
        c, h, w = I1.shape[1:]
        p = (self.gauss_kernel.shape[-1] - 1) // 2
        I1_blur = F.conv2d(
            F.pad(I1, pad=(p, p, p, p), mode="reflect"), self.gauss_kernel, groups=c
        )
        I2_blur = F.conv2d(
            F.pad(I2, pad=(p, p, p, p), mode="reflect"), self.gauss_kernel, groups=c
        )
        I1_ds = F.interpolate(
            I1_blur, size=(h // 2, w // 2), mode="bilinear", align_corners=False
        )
        I2_ds = F.interpolate(
            I2_blur, size=(h // 2, w // 2), mode="bilinear", align_corners=False
        )
        tensorAdaCoF1_ds = (
            self.moduleAdaCoF(
                self.modulePad(I1_ds), Weight1_ds, Alpha1_ds, Beta1_ds, self.dilation
            )
            * 1.0
        )
        tensorAdaCoF2_ds = (
            self.moduleAdaCoF(
                self.modulePad(I2_ds), Weight2_ds, Alpha2_ds, Beta2_ds, self.dilation
            )
            * 1.0
        )

        # x2 upsampled version
        I1_us = self.upsampler(I1)
        I2_us = self.upsampler(I2)
        tensorAdaCoF1_us = (
            self.moduleAdaCoF(
                self.modulePad(I1_us), Weight1_us, Alpha1_us, Beta1_us, self.dilation
            )
            * 1.0
        )
        tensorAdaCoF2_us = (
            self.moduleAdaCoF(
                self.modulePad(I2_us), Weight2_us, Alpha2_us, Beta2_us, self.dilation
            )
            * 1.0
        )

        # use softsplat for refinement
        pyramid0, pyramid2 = self.flow_estimator.extract_pyramid(I1, I2)
        flow_0_2 = 20 * self.flow_estimator(I1, I2, pyramid0, pyramid2)
        flow_0_2 = F.interpolate(
            flow_0_2, size=(h, w), mode="bilinear", align_corners=False
        )
        flow_2_0 = 20 * self.flow_estimator(I2, I1, pyramid2, pyramid0)
        flow_2_0 = F.interpolate(
            flow_2_0, size=(h, w), mode="bilinear", align_corners=False
        )
        metric_0_2 = self.metric(I1, I2, flow_0_2)
        metric_2_0 = self.metric(I2, I1, flow_2_0)
        tensorSoftsplat0 = self.softsplat(I1, 0.5 * flow_0_2, metric_0_2)
        tensorSoftsplat2 = self.softsplat(I2, 0.5 * flow_2_0, metric_2_0)

        # synthesize multiple scales
        tensorCombine_us = torch.cat([tensorAdaCoF1_us, tensorAdaCoF2_us], dim=1)
        tensorCombine = torch.cat(
            [tensorAdaCoF1, tensorAdaCoF2, tensorSoftsplat0, tensorSoftsplat2], dim=1
        )
        tensorCombine_ds = torch.cat([tensorAdaCoF1_ds, tensorAdaCoF2_ds], dim=1)
        output_tilde = self.scale_synthesis(
            tensorCombine_us, tensorCombine, tensorCombine_ds
        )[0]

        # generate dynamic texture
        dyntex = self.dyntex_generator(I0, I1, I2, I3, output_tilde)
        output = output_tilde + dyntex

        if h_padded:
            output = output[:, :, 0:h0, :]
        if w_padded:
            output = output[:, :, :, 0:w0]

        if self.training:
            return {"frame1": output}
        else:
            return output


student = student_STMFNet(args)
student.to(device);

### distillation model

In [30]:
args=trainArgs()

softmax_optimiser = nn.Softmax(dim=1)
mse_loss_function = nn.MSELoss()

def my_loss(scores, targets, temperature = 5):
    soft_pred = softmax_optimiser(scores / temperature)
    soft_targets = softmax_optimiser(targets / temperature)
    loss = mse_loss_function(soft_pred, soft_targets)
    return loss

distil_optimizer = optim.Adam(student.parameters(), lr=0.0001)

losses = []

for epoch in range(5):

	running_loss = 0.0
	for i, data in enumerate(train_loader, 1):

		inputs, labels = data[0].to(device), data[1].to(device)

		targets = teacher(inputs)
		scores = student(inputs)
		loss = my_loss(scores, targets, temperature = 2)
		distil_optimizer.zero_grad()
		loss.backward()
		distil_optimizer.step()

		# print statistics
		running_loss += loss.item()
		if i % 60 == 59:    # print every 60 mini-batches
			print(f'[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 60:.3f}')
			running_loss = 0.0
		
	print('appending loss: ', loss.item())
	losses.append(loss.item())

TypeError: forward() missing 3 required positional arguments: 'I1', 'I2', and 'I3'

In [31]:
args = trainArgs()
import losses
loss = losses.DistillationLoss(args)

start_epoch = 0
if args.load is not None:
    checkpoint = torch.load(args.load)
    student.load_state_dict(checkpoint["state_dict"])
    start_epoch = checkpoint["epoch"]

my_trainer = Distiller(args, train_loader, valid_loader, student, loss, start_epoch)

now = datetime.datetime.now().strftime("%Y-%m-%d-%H:%M:%S")
with open(join(args.out_dir, "config.txt"), "a") as f:
    f.write(now + "\n\n")
    for arg in vars(args):
        f.write("{}: {}\n".format(arg, getattr(args, arg)))
    f.write("\n")

while not my_trainer.terminate():
    my_trainer.train()
    my_trainer.save_checkpoint()
    my_trainer.validate()


1.000 * Lap
loss  1 :  tensor(116355.4453, device='cuda:0', grad_fn=<AddBackward0>)
loss  2 :  tensor(209200.2500, device='cuda:0', grad_fn=<AddBackward0>)
loss  3 :  tensor(106750.4141, device='cuda:0', grad_fn=<AddBackward0>)
loss  4 :  tensor(77347.9844, device='cuda:0', grad_fn=<AddBackward0>)
loss  5 :  tensor(49366.7305, device='cuda:0', grad_fn=<AddBackward0>)
loss  6 :  tensor(36558.9688, device='cuda:0', grad_fn=<AddBackward0>)
loss  7 :  tensor(28091.4160, device='cuda:0', grad_fn=<AddBackward0>)
loss  8 :  tensor(27802.7461, device='cuda:0', grad_fn=<AddBackward0>)
loss  9 :  tensor(39013.3984, device='cuda:0', grad_fn=<AddBackward0>)
loss  10 :  tensor(31195.1328, device='cuda:0', grad_fn=<AddBackward0>)
loss  11 :  tensor(23082.6484, device='cuda:0', grad_fn=<AddBackward0>)
loss  12 :  tensor(28373.3301, device='cuda:0', grad_fn=<AddBackward0>)
loss  13 :  tensor(26240.2266, device='cuda:0', grad_fn=<AddBackward0>)
loss  14 :  tensor(22674.9688, device='cuda:0', grad_fn=<A

KeyboardInterrupt: 

In [34]:
loss

Loss(
  (loss_module): ModuleList(
    (0): LaplacianLoss(
      (criterion): L1Loss()
      (lap): LaplacianPyramid(
        (gaussian_conv): GaussianConv()
      )
    )
  )
)

In [21]:
args=testArgs()
print("Testing on dataset: ", args.dataset)
test_dir = os.path.join(args.out_dir, args.dataset)
if args.dataset.split("_")[0] in ["VFITex", "Ucf101", "Davis90"]:
    db_folder = args.dataset.split("_")[0].lower()
else:
    db_folder = args.dataset.lower()
test_db = getattr(testsets, args.dataset)(os.path.join(args.data_dir, db_folder))
if not os.path.exists(test_dir):
    os.mkdir(test_dir)

test_db.eval(teacher, metrics=args.metrics, output_dir=test_dir)

Testing on dataset:  Ucf101_quintuplet




0               -- {'PSNR': 26.01, 'SSIM': 0.84}
1               -- {'PSNR': 41.335, 'SSIM': 0.997}
10              -- {'PSNR': 33.793, 'SSIM': 0.991}
11              -- {'PSNR': 31.247, 'SSIM': 0.973}
12              -- {'PSNR': 34.663, 'SSIM': 0.985}
13              -- {'PSNR': 35.933, 'SSIM': 0.988}
14              -- {'PSNR': 29.769, 'SSIM': 0.955}
15              -- {'PSNR': 29.514, 'SSIM': 0.974}
16              -- {'PSNR': 33.739, 'SSIM': 0.99}
17              -- {'PSNR': 25.378, 'SSIM': 0.946}
18              -- {'PSNR': 33.509, 'SSIM': 0.993}
19              -- {'PSNR': 38.037, 'SSIM': 0.994}
2               -- {'PSNR': 39.235, 'SSIM': 0.996}
20              -- {'PSNR': 37.453, 'SSIM': 0.992}
21              -- {'PSNR': 31.144, 'SSIM': 0.983}
22              -- {'PSNR': 35.097, 'SSIM': 0.985}
23              -- {'PSNR': 28.47, 'SSIM': 0.971}
24              -- {'PSNR': 40.438, 'SSIM': 0.998}
25              -- {'PSNR': 29.044, 'SSIM': 0.943}
26              -- {'PSNR': 38.616,

In [10]:
from torchvision import transforms
from PIL import Image
from os.path import join, exists
import utility
from torchvision.utils import save_image as imwrite

db_dir = './tests/'
transform = transforms.Compose([transforms.ToTensor()])

im_list = os.listdir(db_dir)

input1_list = []
input3_list = []
input5_list = []
input7_list = []
gt_list = []
for item in im_list:
    input1_list.append(
        transform(Image.open(join(db_dir, item, "frame0.png")))
        .cuda()
        .unsqueeze(0)
    )
    input3_list.append(
        transform(Image.open(join(db_dir, item, "frame1.png")))
        .cuda()
        .unsqueeze(0)
    )
    input5_list.append(
        transform(Image.open(join(db_dir, item, "frame2.png")))
        .cuda()
        .unsqueeze(0)
    )
    input7_list.append(
        transform(Image.open(join(db_dir, item, "frame3.png")))
        .cuda()
        .unsqueeze(0)
    )
    gt_list.append(
        transform(Image.open(join(db_dir, item, "framet.png")))
        .cuda()
        .unsqueeze(0)
    )

# def eval(model, , output_dir=None, output_name="output.png"):
# model.eval()

output_dir = "./tests/"
output_name = "output.png"


# results_dict = {k: [] for k in metrics}

# logfile = open(join(output_dir, "results.txt"), "a")

for idx in range(len(im_list)):
    if not exists(join(output_dir, im_list[idx])):
        os.makedirs(join(output_dir, im_list[idx]))

    with torch.no_grad():
        out = model(
            input1_list[idx],
            input3_list[idx],
            input5_list[idx],
            input7_list[idx],
        )
    gt = gt_list[idx]


    imwrite(out, join(output_dir, im_list[idx], output_name), range=(0, 1))

#     msg = (
#         "{:<15s} -- {}".format(
#             im_list[idx],
#             {k: round(results_dict[k][-1], 3) for k in metrics},
#         )
#         + "\n"
#     )
#     print(msg, end="")
#     logfile.write(msg)

# msg = (
#     "{:<15s} -- {}".format(
#         "Average", {k: round(np.mean(results_dict[k]), 3) for k in metrics}
#     )
#     + "\n\n"
# )
# print(msg, end="")
# logfile.write(msg)
# logfile.close()
torch.cuda.empty_cache()