In [1]:
import os

from data import testsets
import torch
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms

from data.datasets import *
from trainers.distiller import Distiller
from torch.utils.data import DataLoader
import models
import losses
import datetime
from os.path import join

# Pick the first CUDA device when one is available, otherwise fall back to CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
# Show which accelerator is in use. `get_device_name()` raises when CUDA is
# unavailable, so only query it when the CUDA branch was actually taken.
torch.cuda.get_device_name() if device.type == 'cuda' else 'cpu'

'NVIDIA GeForce RTX 3090'

In [2]:
class _SharedArgs:
    """Settings that `testArgs` and `trainArgs` previously duplicated.

    Keeping them in one place stops the train and test configurations from
    drifting apart. The values are unchanged from the original classes.
    """
    gpu_id = 0
    net = 'STMFNet'
    # Dataset root. Override with the STMFNET_DATA_DIR environment variable
    # instead of editing the notebook; the default is the original path.
    data_dir = os.environ.get('STMFNET_DATA_DIR', 'D:/stmfnet_data')
    # Shared by both subclasses as a single list object -- treat as read-only.
    featc = [64, 128, 256, 512]
    featnet = 'UMultiScaleResNext'
    featnorm = 'batch'
    kernel_size = 5
    dilation = 1
    finetune_pwc = False


class testArgs(_SharedArgs):
    """Attribute bag used by the evaluation cell (stands in for parsed CLI args)."""
    dataset = 'Ucf101_quintuplet'
    metrics = ['PSNR', 'SSIM']
    checkpoint = './models/stmfnet.pth'
    out_dir = './tests'


class trainArgs(_SharedArgs):
    """Attribute bag used by the data-loading, model and distillation cells."""
    out_dir = './train_results'
    load = None
    epochs = 70
    batch_size = 2
    loss = "1*Lap"
    patch_size = 256
    lr = 0.001
    lr_decay = 20
    decay_type = 'step'
    gamma = 0.5
    patience = None
    optimizer = 'ADAMax'
    weight_decay = 0
    # Names looked up on the `models` module to build the two networks.
    teacher = "STMFNet"
    student = "student_STMFNet"


# Active configuration for the cells that follow.
args = trainArgs()

### import data

In [3]:
# Make the configured GPU the current CUDA device.
torch.cuda.set_device(args.gpu_id)

# Values reused by several dataset constructors below.
patch_hw = (args.patch_size, args.patch_size)
vimeo_root = join(args.data_dir, "vimeo_septuplet")
bvidvc_root = join(args.data_dir, "bvidvc")

# Training sources: Vimeo-90k (train split) and BVI-DVC, both cropped to patches.
vimeo90k_train = Vimeo90k_quintuplet(vimeo_root, train=True, crop_sz=patch_hw)
bvidvc_train = BVIDVC_quintuplet(bvidvc_root, crop_sz=patch_hw)

# Validation source: the non-train Vimeo-90k split with both augmentation flags off.
vimeo90k_valid = Vimeo90k_quintuplet(
    vimeo_root,
    train=False,
    crop_sz=patch_hw,
    augment_s=False,
    augment_t=False,
)

# `Sampler` comes from `data.datasets`; it wraps the list of training sets
# (presumably mixing them into one dataset -- confirm against its definition).
datasets_train = [vimeo90k_train, bvidvc_train]
train_sampler = Sampler(datasets_train, iter=True)

# Loaders run in the main process (num_workers=0); only training is shuffled.
train_loader = DataLoader(
    train_sampler,
    batch_size=args.batch_size,
    shuffle=True,
    num_workers=0,
)
valid_loader = DataLoader(
    vimeo90k_valid,
    batch_size=args.batch_size,
    num_workers=0,
)

### teacher model

In [4]:
# Load the model

# Make the configured GPU the current CUDA device so `.cuda()` below targets it.
torch.cuda.set_device(args.gpu_id)

# `makedirs(..., exist_ok=True)` also creates missing parent directories and
# does not fail if the directory already exists.
os.makedirs(args.out_dir, exist_ok=True)


# Build the teacher by looking up the class named in `args.teacher` on `models`.
# NOTE(review): no checkpoint is loaded in this cell, so the teacher's weights
# are whatever the constructor produces unless they are loaded elsewhere
# (e.g. inside the model class or the trainer) -- confirm before distilling.
teacher = getattr(models, args.teacher)(args).cuda()
# def load_model(filepath):

#     checkpoint = torch.load(filepath)
#     model = STMFNet(args).cuda()
#     model.load_state_dict(checkpoint['state_dict'])
    
#     return model

# model = load_model("./models/stmfnet.pth")



### student model

In [5]:
# Build the student by looking up the class named in `args.student` on `models`
# and place it on the current CUDA device, exactly as is done for the teacher.
# The previous extra `student.to(device)` was a no-op for gpu_id == 0 and, for
# any other gpu_id, moved the student to cuda:0 while the teacher stayed on
# the selected GPU.
student = getattr(models, args.student)(args).cuda()

### distillation model

In [6]:
# args=trainArgs()

# softmax_optimiser = nn.Softmax(dim=1)
# mse_loss_function = nn.MSELoss()

# def my_loss(scores, targets, temperature = 5):
#     soft_pred = softmax_optimiser(scores / temperature)
#     soft_targets = softmax_optimiser(targets / temperature)
#     loss = mse_loss_function(soft_pred, soft_targets)
#     return loss

# distil_optimizer = optim.Adam(student.parameters(), lr=0.0001)

# losses = []

# for epoch in range(5):

# 	running_loss = 0.0
# 	for i, data in enumerate(train_loader, 1):

# 		inputs, labels = data[0].to(device), data[1].to(device)

# 		targets = teacher(inputs)
# 		scores = student(inputs)
# 		loss = my_loss(scores, targets, temperature = 2)
# 		distil_optimizer.zero_grad()
# 		loss.backward()
# 		distil_optimizer.step()

# 		# print statistics
# 		running_loss += loss.item()
# 		if i % 60 == 59:    # print every 60 mini-batches
# 			print(f'[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 60:.3f}')
# 			running_loss = 0.0
		
# 	print('appending loss: ', loss.item())
# 	losses.append(loss.item())

In [13]:
# from torchinfo import summary

# Exploratory cell: `parameters` is accessed without being called, so the cell
# displays the bound-method repr, which embeds the repr of the whole student
# module (its layer tree).
student.parameters

<bound method Module.parameters of student_STMFNet(
  (feature_extractor): UMultiScaleResNext(
    (conv1): MultiScaleResNextBlock(
      (resnext_small): ResNextBlock(
        (conv1): Conv2d(6, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=32, bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv3): Conv2d(64, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (downsample): Sequential(
          (0): Conv2d(6, 32, kernel_size=(1, 1), stride=(2, 2), bias=False)
          (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
      (resnext_large): Re

In [9]:
# Exploratory cell: list the attribute names available on the student module.
dir(student)

['T_destination',
 '__annotations__',
 '__call__',
 '__class__',
 '__delattr__',
 '__dict__',
 '__dir__',
 '__doc__',
 '__eq__',
 '__format__',
 '__ge__',
 '__getattr__',
 '__getattribute__',
 '__gt__',
 '__hash__',
 '__init__',
 '__init_subclass__',
 '__le__',
 '__lt__',
 '__module__',
 '__ne__',
 '__new__',
 '__reduce__',
 '__reduce_ex__',
 '__repr__',
 '__setattr__',
 '__setstate__',
 '__sizeof__',
 '__str__',
 '__subclasshook__',
 '__weakref__',
 '_apply',
 '_backward_hooks',
 '_buffers',
 '_call_impl',
 '_forward_hooks',
 '_forward_pre_hooks',
 '_get_backward_hooks',
 '_get_name',
 '_is_full_backward_hook',
 '_load_from_state_dict',
 '_load_state_dict_pre_hooks',
 '_maybe_warn_non_full_backward_hook',
 '_modules',
 '_named_members',
 '_non_persistent_buffers_set',
 '_parameters',
 '_register_load_state_dict_pre_hook',
 '_register_state_dict_hook',
 '_replicate_for_data_parallel',
 '_save_to_state_dict',
 '_slow_forward',
 '_state_dict_hooks',
 '_version',
 'add_module',
 'apply',


In [7]:
# Fresh training configuration; the loss specification is set on the instance.
args = trainArgs()
args.loss = "1*Lap"

# Re-imported here so the name `losses` refers to the module even if an earlier
# cell rebound it.
import losses
loss = losses.DistillationLoss(args)

start_epoch = 0
# if args.load is not None:
#     checkpoint = torch.load(args.load)
#     student.load_state_dict(checkpoint["state_dict"])
#     start_epoch = checkpoint["epoch"]

# NOTE(review): the optimizer type and learning rate are hard-coded here and
# differ from `args.optimizer` ('ADAMax') and `args.lr` (0.001); config.txt
# records the `args` values, not the ones actually used -- confirm intent.
distill_optimizer = optim.Adam(student.parameters(), lr=0.0001)
my_trainer = Distiller(args, train_loader, valid_loader, student, teacher, loss, distill_optimizer, start_epoch)

# The config file is written into out_dir, so make sure it exists even when the
# cell that normally creates it has not been run.
os.makedirs(args.out_dir, exist_ok=True)

# Collect every public, non-callable setting. `vars(args)` would only return
# attributes set on the instance (just `loss` here) and silently omit all the
# class-level defaults, so walk `dir(args)` instead.
config_items = {
    name: getattr(args, name)
    for name in dir(args)
    if not name.startswith("_") and not callable(getattr(args, name))
}

# Append a timestamped dump of the configuration for this run.
now = datetime.datetime.now().strftime("%Y-%m-%d-%H:%M:%S")
with open(join(args.out_dir, "config.txt"), "a") as f:
    f.write(now + "\n\n")
    for name, value in config_items.items():
        f.write("{}: {}\n".format(name, value))
    f.write("\n")

# One pass per epoch: train, checkpoint, validate -- until the trainer says stop.
while not my_trainer.terminate():
    my_trainer.train()
    my_trainer.save_checkpoint()
    my_trainer.validate()


1.000 * Lap
tensor(4184.4849, device='cuda:0', grad_fn=<AddBackward0>)
tensor(3466.2898, device='cuda:0', grad_fn=<AddBackward0>)
tensor(3409.7297, device='cuda:0', grad_fn=<AddBackward0>)
tensor(2822.0674, device='cuda:0', grad_fn=<AddBackward0>)
tensor(2794.1367, device='cuda:0', grad_fn=<AddBackward0>)
tensor(2594.4651, device='cuda:0', grad_fn=<AddBackward0>)
tensor(2780.1660, device='cuda:0', grad_fn=<AddBackward0>)
tensor(2150.9905, device='cuda:0', grad_fn=<AddBackward0>)
tensor(2675.2241, device='cuda:0', grad_fn=<AddBackward0>)
tensor(2651.7588, device='cuda:0', grad_fn=<AddBackward0>)
tensor(2346.7141, device='cuda:0', grad_fn=<AddBackward0>)


KeyboardInterrupt: 

In [None]:
# Exploratory cell: display the repr of the loss object built in the training cell.
loss

Loss(
  (loss_module): ModuleList(
    (0): LaplacianLoss(
      (criterion): L1Loss()
      (lap): LaplacianPyramid(
        (gaussian_conv): GaussianConv()
      )
    )
  )
)

In [None]:
# Switch to the evaluation configuration (rebinds the notebook-wide `args`).
args = testArgs()
print("Testing on dataset: ", args.dataset)
test_dir = os.path.join(args.out_dir, args.dataset)

# For these dataset families the on-disk folder is named after the part of the
# dataset name before the first underscore; otherwise the full name is used.
dataset_prefix = args.dataset.split("_")[0]
if dataset_prefix in ["VFITex", "Ucf101", "Davis90"]:
    db_folder = dataset_prefix.lower()
else:
    db_folder = args.dataset.lower()
test_db = getattr(testsets, args.dataset)(os.path.join(args.data_dir, db_folder))

# `test_dir` is nested under `out_dir`, which may not exist yet; `makedirs`
# creates both levels and tolerates an existing directory.
os.makedirs(test_dir, exist_ok=True)

# Evaluate the teacher network on the selected test set with the configured metrics.
test_db.eval(teacher, metrics=args.metrics, output_dir=test_dir)

Testing on dataset:  Ucf101_quintuplet




0               -- {'PSNR': 26.01, 'SSIM': 0.84}
1               -- {'PSNR': 41.335, 'SSIM': 0.997}
10              -- {'PSNR': 33.793, 'SSIM': 0.991}
11              -- {'PSNR': 31.247, 'SSIM': 0.973}
12              -- {'PSNR': 34.663, 'SSIM': 0.985}
13              -- {'PSNR': 35.933, 'SSIM': 0.988}
14              -- {'PSNR': 29.769, 'SSIM': 0.955}
15              -- {'PSNR': 29.514, 'SSIM': 0.974}
16              -- {'PSNR': 33.739, 'SSIM': 0.99}
17              -- {'PSNR': 25.378, 'SSIM': 0.946}
18              -- {'PSNR': 33.509, 'SSIM': 0.993}
19              -- {'PSNR': 38.037, 'SSIM': 0.994}
2               -- {'PSNR': 39.235, 'SSIM': 0.996}
20              -- {'PSNR': 37.453, 'SSIM': 0.992}
21              -- {'PSNR': 31.144, 'SSIM': 0.983}
22              -- {'PSNR': 35.097, 'SSIM': 0.985}
23              -- {'PSNR': 28.47, 'SSIM': 0.971}
24              -- {'PSNR': 40.438, 'SSIM': 0.998}
25              -- {'PSNR': 29.044, 'SSIM': 0.943}
26              -- {'PSNR': 38.616,

In [None]:
from torchvision import transforms
from PIL import Image
from os.path import join, exists
import utility
from torchvision.utils import save_image as imwrite

# Each sub-folder of `db_dir` holds four input frames; the interpolated frame
# is written back into the same sub-folder under `output_dir`.
db_dir = './tests/'
output_dir = "./tests/"
output_name = "output.png"
INPUT_FRAME_NAMES = ("frame0.png", "frame1.png", "frame2.png", "frame3.png")

transform = transforms.Compose([transforms.ToTensor()])

# Network used for inference. No visible cell defines `model`, so bind it
# explicitly here.
# NOTE(review): `teacher` is assumed to be the intended network (it is the
# STMFNet built earlier); use `student` instead to render the distilled model.
model = teacher
# Inference must run in eval mode: in train mode BatchNorm layers use batch
# statistics and update their running statistics even under `torch.no_grad()`.
model.eval()


def _load_frame(path):
    """Read one image file and return it as a CUDA tensor with a leading batch dimension."""
    # The context manager closes the file handle once the pixels are converted.
    with Image.open(path) as img:
        return transform(img).cuda().unsqueeze(0)


# Keep only entries that actually contain all four input frames, so stray files
# or result folders under `db_dir` do not crash the loop. Sorted for a
# deterministic processing order.
im_list = [
    item
    for item in sorted(os.listdir(db_dir))
    if all(exists(join(db_dir, item, name)) for name in INPUT_FRAME_NAMES)
]

# TODO: per-sequence metric logging (ground-truth "framet.png", results.txt)
# was disabled in the original cell and is not reproduced here.
with torch.no_grad():
    for item in im_list:
        item_out_dir = join(output_dir, item)
        os.makedirs(item_out_dir, exist_ok=True)

        # Load one sequence at a time instead of holding every frame of every
        # sequence on the GPU at once.
        frames = [_load_frame(join(db_dir, item, name)) for name in INPUT_FRAME_NAMES]
        out = model(*frames)

        # The old `range=(0, 1)` keyword is dropped: it only had an effect with
        # `normalize=True` and is rejected by recent torchvision releases.
        imwrite(out, join(item_out_dir, output_name))

# Return cached, unused GPU memory to the driver.
torch.cuda.empty_cache()