In [1]:
import os
import math
from data import testsets
import torch
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import numpy as np
from data.datasets import *
from trainers.distiller import Distiller
from torch.utils.data import DataLoader
from models.stmfnet import STMFNet
from models.student import student_STMFNet
import models
import losses
import datetime
from os.path import join
from torchinfo import summary
import pandas as pd

# Use the first CUDA device when one is available, otherwise fall back to the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
# Display the accelerator in use. Querying the CUDA device name raises on a
# CPU-only machine, so only ask for it when a CUDA device was selected.
torch.cuda.get_device_name(device) if device.type == 'cuda' else 'cpu'

'NVIDIA GeForce RTX 3090'

In [2]:
class testArgs:
    """Plain attribute container holding the settings for an evaluation run."""

    # runtime / paths
    gpu_id = 0
    data_dir = "D:/stmfnet_data"
    out_dir = "./tests/results"

    # what gets evaluated (alternative checkpoint: './models/stmfnet.pth')
    net = "STMFNet"
    checkpoint = "./train_results/checkpoint/model_epoch008.pth"
    dataset = "Ucf101_quintuplet"
    metrics = ["PSNR", "SSIM"]

    # network construction options
    featnet = "UMultiScaleResNext"
    featc = [64, 128, 256, 512]
    featnorm = "batch"
    kernel_size = 5
    dilation = 1
    finetune_pwc = False

class trainArgs:
    """Plain attribute container holding the settings for a distillation/training run."""

    # runtime / paths
    gpu_id = 0
    data_dir = "D:/stmfnet_data"
    out_dir = "./train_results"
    load = None

    # training schedule and optimiser settings
    epochs = 70
    batch_size = 2
    patch_size = 256
    loss = "1*Lap"
    optimizer = "ADAMax"
    lr = 0.001
    weight_decay = 0
    decay_type = "step"
    lr_decay = 20
    gamma = 0.5
    patience = None

    # network construction options (narrower feature widths than testArgs)
    net = "STMFNet"
    featnet = "UMultiScaleResNext"
    featc = [32, 64, 96, 128]
    featnorm = "batch"
    kernel_size = 5
    dilation = 1
    finetune_pwc = False

    # teacher/student pairing and distillation settings
    teacher = "STMFNet"
    student = "student_STMFNet"
    distill_loss_fn = "KLDivLoss"
    temp = 10
    alpha = 0.1


# Instance used by the cells that follow.
args = trainArgs()

In [3]:
# Look the student class up by name in `models`, build it from `args`, and put
# it on the device selected in the first cell (one move instead of
# `.cuda()` followed by a second `.to(device)`, so the CPU fallback works too).
student = getattr(models, args.student)(args).to(device)
print(student)

# Layer-by-layer summary for four input frames of shape (2, 3, 256, 256).
model_summary = summary(student, [(2, 3, 256, 256), (2, 3, 256, 256), (2, 3, 256, 256), (2, 3, 256, 256)])

# Make sure the target folder exists; the `with` block closes the file itself,
# so no explicit close() is needed.
os.makedirs('./summaries', exist_ok=True)
with open('./summaries/STUDENT.txt', 'w', encoding="utf-8") as f:
    f.write(str(model_summary))

# print(student)

student_STMFNet(
  (feature_extractor): UMultiScaleResNext(
    (conv1): MultiScaleResNextBlock(
      (resnext_small): ResNextBlock(
        (conv1): Conv2d(6, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv2d(32, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=32, bias=False)
        (bn2): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv3): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (downsample): Sequential(
          (0): Conv2d(6, 16, kernel_size=(1, 1), stride=(2, 2), bias=False)
          (1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
      (resnext_large): ResNextBlock(
        (conv1): Conv2d

### Weights Comparison Automation:

define args

In [2]:
class trainArgs:
    """Plain attribute container with the settings used to build the networks compared below."""

    # runtime / paths
    gpu_id = 0
    data_dir = "D:/stmfnet_data"
    out_dir = "./train_results"
    load = None

    # training schedule and optimiser settings
    epochs = 70
    batch_size = 2
    patch_size = 256
    loss = "1*Lap"
    optimizer = "ADAMax"
    lr = 0.001
    weight_decay = 0
    decay_type = "step"
    lr_decay = 20
    gamma = 0.5
    patience = None

    # network construction options
    net = "STMFNet"
    featnet = "UMultiScaleResNext"
    featc = [64, 128, 256, 512]
    featnorm = "batch"
    kernel_size = 5
    dilation = 1
    finetune_pwc = False

    # distillation settings
    distill_loss_fn = "KLDivLoss"
    temp = 10
    alpha = 0.3


# Instance used by the cells that follow.
args = trainArgs()

load models:

In [3]:
def to_device(data, device):
    """Recursively move ``data`` onto ``device``.

    Lists and tuples are walked element by element and always come back as a
    plain ``list`` (a tuple input is not preserved as a tuple). Any other
    object is moved through its own ``.to(device, non_blocking=True)``.
    """
    if not isinstance(data, (list, tuple)):
        return data.to(device, non_blocking=True)

    moved = []
    for item in data:
        moved.append(to_device(item, device))
    return moved

# Reference network: build it on `device` and load the released weights.
# `map_location` remaps tensors that were saved from a GPU, so loading also
# works when `device` has fallen back to the CPU.
teacher = to_device(STMFNet(args), device)
t_checkpoint = torch.load('./models/stmfnet.pth', map_location=device)
teacher.load_state_dict(t_checkpoint['state_dict'])


# Network under comparison: same architecture, weights from a training checkpoint.
model = to_device(STMFNet(args), device)
m_checkpoint = torch.load("./train_results/checkpoint/model_epoch005.pth", map_location=device)
# Left as the last expression so the key-matching report is displayed.
model.load_state_dict(m_checkpoint['state_dict'])

<All keys matched successfully>

calculate the number of zeros in each layer:

In [4]:
# Every parameter whose name contains "weight", for both networks.
teacher_weights = [param for key, param in teacher.named_parameters() if "weight" in key]
model_weights = [param for key, param in model.named_parameters() if "weight" in key]

# Per-layer statistics for `model`, restricted to weight tensors with more than
# one dimension. Collected in a single pass so the four results stay aligned.
layer_names, layer_shapes = [], []
param_counts, zero_counts = [], []
for key, param in model.named_parameters():
    if "weight" not in key or len(param.shape) <= 1:
        continue
    layer_names.append(key)
    layer_shapes.append(param.shape)
    param_counts.append(np.prod(param.shape))
    zero_counts.append(np.count_nonzero(param.detach().cpu() == 0))

K_parameters = np.array(param_counts)
num_zeros = np.array(zero_counts)


find the compression ratio of each layer (the fraction of its weights that are non-zero):

In [5]:
# Sanity check: how many weight tensors passed the "more than one dimension" filter.
len(layer_shapes)

295

In [6]:
# layer_names is filtered the same way as layer_shapes, so the two counts should agree.
print(len(layer_names))
layer_names

295


['feature_extractor.conv1.resnext_small.conv1.weight',
 'feature_extractor.conv1.resnext_small.conv2.weight',
 'feature_extractor.conv1.resnext_small.conv3.weight',
 'feature_extractor.conv1.resnext_small.downsample.0.weight',
 'feature_extractor.conv1.resnext_large.conv1.weight',
 'feature_extractor.conv1.resnext_large.conv2.weight',
 'feature_extractor.conv1.resnext_large.conv3.weight',
 'feature_extractor.conv1.resnext_large.downsample.0.weight',
 'feature_extractor.conv1.attention.fc.0.weight',
 'feature_extractor.conv1.attention.fc.2.weight',
 'feature_extractor.conv2.resnext_small.conv1.weight',
 'feature_extractor.conv2.resnext_small.conv2.weight',
 'feature_extractor.conv2.resnext_small.conv3.weight',
 'feature_extractor.conv2.resnext_small.downsample.0.weight',
 'feature_extractor.conv2.resnext_large.conv1.weight',
 'feature_extractor.conv2.resnext_large.conv2.weight',
 'feature_extractor.conv2.resnext_large.conv3.weight',
 'feature_extractor.conv2.resnext_large.downsample.0.w

In [7]:
# Fraction of each layer's weights that are non-zero: 1 - (zeros / total parameters).
compression_ratios = 1 - (num_zeros / K_parameters)

In [8]:
# Shape of every analysed weight tensor spread over five columns; tensors with
# fewer than five dimensions get 0 in the columns they do not have.
layer_shapes_df = pd.DataFrame(
    layer_shapes, columns=['C_out', 'C_in', 'q1', 'q2', 'q3']
).fillna(0)

# Zero counts, non-zero fraction and the raw shape, one row per layer.
df = pd.DataFrame(num_zeros, columns=['num_zeros']).assign(
    compression_ratios=compression_ratios,
    layer_shapes=layer_shapes,
)

data = pd.concat([layer_shapes_df, df], axis=1)

# Scale C_in by the non-zero fraction and record how far the rounded result is
# from the current C_in.
data['new_C_in'] = (data['C_in'] * data['compression_ratios']).round()
data['difference'] = (data['C_in'] - data['new_C_in']).abs()
data['layer_names'] = layer_names

print(max(data['difference']))
data.to_csv('./layer_shapes.csv')

data


648.0


Unnamed: 0,C_out,C_in,q1,q2,q3,num_zeros,compression_ratios,layer_shapes,new_C_in,difference,layer_names
0,64,6,1.0,1.0,0.0,41,0.893229,"(64, 6, 1, 1)",5.0,1.0,feature_extractor.conv1.resnext_small.conv1.we...
1,64,2,3.0,3.0,0.0,155,0.865451,"(64, 2, 3, 3)",2.0,0.0,feature_extractor.conv1.resnext_small.conv2.we...
2,32,64,1.0,1.0,0.0,437,0.786621,"(32, 64, 1, 1)",50.0,14.0,feature_extractor.conv1.resnext_small.conv3.we...
3,32,6,1.0,1.0,0.0,18,0.906250,"(32, 6, 1, 1)",5.0,1.0,feature_extractor.conv1.resnext_small.downsamp...
4,64,6,1.0,1.0,0.0,36,0.906250,"(64, 6, 1, 1)",5.0,1.0,feature_extractor.conv1.resnext_large.conv1.we...
...,...,...,...,...,...,...,...,...,...,...,...
290,32,32,1.0,1.0,1.0,283,0.723633,"(32, 32, 1, 1, 1)",23.0,9.0,dyntex_generator.decoder.3.conv.1.attn_layer.0...
291,64,32,3.0,4.0,4.0,24996,0.745728,"(64, 32, 3, 4, 4)",24.0,8.0,dyntex_generator.decoder.4.upconv.0.weight
292,32,32,1.0,1.0,1.0,188,0.816406,"(32, 32, 1, 1, 1)",26.0,6.0,dyntex_generator.decoder.4.upconv.1.attn_layer...
293,32,160,1.0,1.0,0.0,1711,0.665820,"(32, 160, 1, 1)",107.0,53.0,dyntex_generator.feature_fuse.0.weight


model summaries:

In [9]:
# Both networks are summarised for four input frames of this shape.
frame_shape = (2, 3, 256, 256)
teacher_summary = summary(teacher, [frame_shape] * 4)
model_summary = summary(model, [frame_shape] * 4)

# Create the output folder if needed; each `with` block closes its own file,
# so no explicit close() is required.
os.makedirs('./summaries', exist_ok=True)
with open('./summaries/model.txt', 'w', encoding="utf-8") as f:
    f.write(str(model_summary))
with open('./summaries/teacher.txt', 'w', encoding="utf-8") as f:
    f.write(str(teacher_summary))

### Knowledge Distillation:

In [2]:
class testArgs:
    """Plain attribute container holding the settings for an evaluation run."""

    # runtime / paths
    gpu_id = 0
    data_dir = "D:/stmfnet_data"
    out_dir = "./tests/results"

    # what gets evaluated (alternative checkpoint: './models/stmfnet.pth')
    net = "STMFNet"
    checkpoint = "./train_results/checkpoint/model_epoch008.pth"
    dataset = "Ucf101_quintuplet"
    metrics = ["PSNR", "SSIM"]

    # network construction options
    featnet = "UMultiScaleResNext"
    featc = [64, 128, 256, 512]
    featnorm = "batch"
    kernel_size = 5
    dilation = 1
    finetune_pwc = False

class trainArgs:
    """Plain attribute container holding the settings for the knowledge-distillation run."""

    # runtime / paths
    gpu_id = 0
    data_dir = "D:/stmfnet_data"
    out_dir = "./train_results"
    load = None

    # training schedule and optimiser settings
    epochs = 70
    batch_size = 2
    patch_size = 256
    loss = "1*Lap"
    optimizer = "ADAMax"
    lr = 0.001
    weight_decay = 0
    decay_type = "step"
    lr_decay = 20
    gamma = 0.5
    patience = None

    # network construction options
    net = "STMFNet"
    featnet = "UMultiScaleResNext"
    featc = [64, 128, 256, 512]
    featnorm = "batch"
    kernel_size = 5
    dilation = 1
    finetune_pwc = False

    # teacher/student pairing and distillation settings
    teacher = "STMFNet"
    student = "student_STMFNet"
    distill_loss_fn = "KLDivLoss"
    temp = 10
    alpha = 0.1


# Instance used by the cells that follow.
args = trainArgs()

### import data

In [3]:
torch.cuda.set_device(args.gpu_id)

# Values shared by several dataset constructors below.
crop_sz = (args.patch_size, args.patch_size)
vimeo_dir = join(args.data_dir, "vimeo_septuplet")

# training sets
vimeo90k_train = Vimeo90k_quintuplet(vimeo_dir, train=True, crop_sz=crop_sz)
bvidvc_train = BVIDVC_quintuplet(join(args.data_dir, "bvidvc"), crop_sz=crop_sz)

# validation set (both augmentation flags switched off)
vimeo90k_valid = Vimeo90k_quintuplet(
    vimeo_dir,
    train=False,
    crop_sz=crop_sz,
    augment_s=False,
    augment_t=False,
)

# Only the BVI-DVC set is handed to the sampler; vimeo90k_train is built above
# but is not part of datasets_train in this cell.
datasets_train = [bvidvc_train]
train_sampler = Sampler(datasets_train, iter=True)

# data loaders
train_loader = DataLoader(
    train_sampler, batch_size=args.batch_size, shuffle=True, num_workers=0
)
valid_loader = DataLoader(vimeo90k_valid, batch_size=args.batch_size, num_workers=0)

### teacher model

In [4]:
# Load the model

torch.cuda.set_device(args.gpu_id)

# makedirs also creates missing parent folders and is a no-op when the folder exists.
os.makedirs(args.out_dir, exist_ok=True)


def load_model(filepath):
    """Build an STMFNet from `args` on the GPU and load the weights stored at `filepath`.

    Args:
        filepath: path to a checkpoint whose 'state_dict' entry holds the weights.

    Returns:
        The STMFNet instance with the checkpoint weights loaded.
    """
    # Deserialise onto the CPU first; load_state_dict then copies the tensors
    # into the parameters that already live on the GPU.
    checkpoint = torch.load(filepath, map_location='cpu')
    model = STMFNet(args).cuda()
    model.load_state_dict(checkpoint['state_dict'])

    return model


model = load_model("./models/stmfnet.pth")

# The teacher must be the network that carries the pretrained weights. The
# previous version rebuilt `teacher` from scratch after loading, which left the
# distillation teacher randomly initialised.
teacher = model
# Inference mode for the teacher (e.g. BatchNorm uses its running statistics).
# NOTE(review): confirm Distiller does not expect the teacher in train mode.
teacher.eval();


### student model

In [5]:
# Look the student class up by name in `models`, build it from `args` and place it on the GPU.
student_cls = getattr(models, args.student)
student = student_cls(args).cuda()
student.to(device);
print(student)

student_STMFNet(
  (feature_extractor): UMultiScaleResNext(
    (conv1): MultiScaleResNextBlock(
      (resnext_small): ResNextBlock(
        (conv1): Conv2d(6, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=32, bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv3): Conv2d(64, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (downsample): Sequential(
          (0): Conv2d(6, 32, kernel_size=(1, 1), stride=(2, 2), bias=False)
          (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
      (resnext_large): ResNextBlock(
        (conv1): Conv2d

### distillation model

In [14]:
# args=trainArgs()

# softmax_optimiser = nn.Softmax(dim=1)
# mse_loss_function = nn.MSELoss()

# def my_loss(scores, targets, temperature = 5):
#     soft_pred = softmax_optimiser(scores / temperature)
#     soft_targets = softmax_optimiser(targets / temperature)
#     loss = mse_loss_function(soft_pred, soft_targets)
#     return loss

# distil_optimizer = optim.Adam(student.parameters(), lr=0.0001)

# losses = []

# for epoch in range(5):

# 	running_loss = 0.0
# 	for i, data in enumerate(train_loader, 1):

# 		inputs, labels = data[0].to(device), data[1].to(device)

# 		targets = teacher(inputs)
# 		scores = student(inputs)
# 		loss = my_loss(scores, targets, temperature = 2)
# 		distil_optimizer.zero_grad()
# 		loss.backward()
# 		distil_optimizer.step()

# 		# print statistics
# 		running_loss += loss.item()
# 		if i % 60 == 59:    # print every 60 mini-batches
# 			print(f'[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 60:.3f}')
# 			running_loss = 0.0
		
# 	print('appending loss: ', loss.item())
# 	losses.append(loss.item())

In [15]:
# from torchinfo import summary


# student.parameters

In [16]:
# Fresh settings object for the distillation run; the loss string is set
# explicitly on this instance before the loss object is built from it.
args = trainArgs()
args.loss = "1*Lap"

import losses
# Loss object handed to the Distiller below.
loss = losses.DistillationLoss(args)

# Training starts from epoch 0; the checkpoint-resume path below is disabled.
start_epoch = 0
# if args.load is not None:
#     checkpoint = torch.load(args.load)
#     student.load_state_dict(checkpoint["state_dict"])
#     start_epoch = checkpoint["epoch"]

# NOTE(review): `distill_optimizer` is created here but is not passed to
# Distiller in this cell — presumably the trainer builds its own optimizer
# from `args`; confirm and drop this line if it is unused.
distill_optimizer = optim.Adam(student.parameters(), lr=0.0001)
my_trainer = Distiller(args, train_loader, valid_loader, student, teacher, loss, start_epoch)

# now = datetime.datetime.now().strftime("%Y-%m-%d-%H:%M:%S")
# with open(join(args.out_dir, "config.txt"), "a") as f:
#     f.write(now + "\n\n")
#     for arg in vars(args):
#         f.write("{}: {}\n".format(arg, getattr(args, arg)))
#     f.write("\n")

# while not my_trainer.terminate():
#     my_trainer.train()
#     my_trainer.save_checkpoint()
#     my_trainer.validate()


1.000 * Lap

Args: {'betas': (0.9, 0.999), 'eps': 1e-08, 'lr': 0.001, 'weight_decay': 0} 



In [17]:
# args=testArgs()

# def to_device(data, device):
#     if isinstance(data, (list,tuple)):
#         return [to_device(x, device) for x in data]
#     return data.to(device, non_blocking=True)

# teacher = to_device(STMFNet(args), device)
# teacher.to(device)
# checkpoint = torch.load(args.checkpoint)
# teacher.load_state_dict(checkpoint['state_dict'])

# print("Testing on dataset: ", args.dataset)
# test_dir = os.path.join(args.out_dir, args.dataset)
# if args.dataset.split("_")[0] in ["VFITex", "Ucf101", "Davis90"]:
#     db_folder = args.dataset.split("_")[0].lower()
# else:
#     db_folder = args.dataset.lower()
# test_db = getattr(testsets, args.dataset)(os.path.join(args.data_dir, db_folder))
# if not os.path.exists(test_dir):
#     os.mkdir(test_dir)

# test_db.eval(teacher, metrics=args.metrics, output_dir=test_dir)

In [18]:
from torchvision import transforms
from PIL import Image
from os.path import join, exists
import utility
from torchvision.utils import save_image as imwrite

db_dir = 'D:/stmfnet_data/ucf101'
transform = transforms.Compose([transforms.ToTensor()])

# One sub-folder per test sample.
im_list = os.listdir(db_dir)


def load_frame(sample, frame_name):
    """Load one frame of a sample as a batched CUDA tensor.

    Args:
        sample: name of the sample folder inside `db_dir`.
        frame_name: file name of the frame inside that folder.

    Returns:
        The image converted by `transform`, moved to the GPU, with a leading
        batch dimension of size 1 added.
    """
    # The context manager closes the image file once it has been converted.
    with Image.open(join(db_dir, sample, frame_name)) as img:
        return transform(img).cuda().unsqueeze(0)


# Four input frames plus the target frame for every sample.
input1_list = [load_frame(item, "frame0.png") for item in im_list]
input3_list = [load_frame(item, "frame1.png") for item in im_list]
input5_list = [load_frame(item, "frame2.png") for item in im_list]
input7_list = [load_frame(item, "frame3.png") for item in im_list]
# Ground-truth frames are loaded but not compared against in this cell
# (no metrics are computed or logged here).
gt_list = [load_frame(item, "framet.png") for item in im_list]

output_dir = "./tests/"
output_name = "output.png"

# Put the teacher in inference mode before predicting one sample at a time,
# so layers such as BatchNorm use their stored statistics.
teacher.eval()

for idx, item in enumerate(im_list):
    os.makedirs(join(output_dir, item), exist_ok=True)

    with torch.no_grad():
        out = teacher(
            input1_list[idx],
            input3_list[idx],
            input5_list[idx],
            input7_list[idx],
        )

    # The former `range=(0, 1)` argument is dropped: it only takes effect
    # together with normalize=True and newer torchvision releases replaced it
    # with `value_range`, so passing it can fail without changing the output.
    imwrite(out, join(output_dir, item, output_name))

torch.cuda.empty_cache()

output shape:  torch.Size([2, 3, 256, 256])
student_STMFNet(
  (feature_extractor): UMultiScaleResNext(
    (conv1): MultiScaleResNextBlock(
      (resnext_small): ResNextBlock(
        (conv1): Conv2d(6, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=32, bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv3): Conv2d(64, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (downsample): Sequential(
          (0): Conv2d(6, 32, kernel_size=(1, 1), stride=(2, 2), bias=False)
          (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
      (resnext_l