In [None]:
# The star import stays first so that the explicit imports below take
# precedence over any same-named objects it may re-export.
from models import *

import math
import os
import ossaudiodev  # NOTE(review): not referenced in the visible cells; confirm it is needed
import pickle

import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
from PIL import Image
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from torchvision import transforms

from model.pytorch_msssim import ssim_matlab

In [None]:
# Prefer the GPU but fall back to the CPU so the notebook can still be
# executed on machines without CUDA.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Teacher used for distillation: build it, load its weights via
# load_model('ours'), switch it to eval mode and move its network to DEVICE.
teacher_model = EMA_VFI_Model()
teacher_model.load_model('ours')
teacher_model.eval()
# The trailing ';' keeps the notebook from echoing the whole module repr.
teacher_model.net.to(DEVICE);

In [None]:
def evaluate(model, list_path, path="/content/Project1/vimeo_triplet/"):
    """Compute per-triplet PSNR and SSIM of ``model`` on a list of triplets.

    For every entry of ``list_path`` the frames ``im1.png`` and ``im3.png`` are
    passed to ``model.inference`` and the prediction is compared against
    ``im2.png``. The averages are printed and the per-sample scores returned.

    Args:
        model: Object exposing ``eval()`` and ``inference(I0, I2)``; the
            predicted frame is read from ``inference(...)[0][0]``. The model is
            left in eval mode when this function returns.
        list_path: Text file with one triplet name per line. Entries that are
            at most one character long after stripping are skipped.
        path: Dataset root; frames are read from
            ``path + 'sequences/' + name + '/imN.png'``.

    Returns:
        ``(psnr_list, ssim_list)`` with one entry per evaluated triplet.

    Raises:
        FileNotFoundError: If ``cv2.imread`` cannot read one of the frames.
    """
    model.eval()
    psnr_list = []
    ssim_list = []
    # `with` closes the list file even when a sample fails; no_grad avoids
    # building autograd graphs that evaluation never uses.
    with open(list_path, 'r') as f, torch.no_grad():
        for line in f:
            name = line.strip()
            if len(name) <= 1:
                continue
            frames = []
            for frame_file in ('im1.png', 'im2.png', 'im3.png'):
                frame_path = path + 'sequences/' + name + '/' + frame_file
                frame = cv2.imread(frame_path)
                if frame is None:
                    # cv2.imread signals failure by returning None instead of raising.
                    raise FileNotFoundError("Could not read frame: " + frame_path)
                frames.append(frame)
            I0, I1, I2 = frames
            # HWC uint8 images -> 1xCxHxW float tensors scaled to [0, 1].
            I0 = (torch.tensor(I0.transpose(2, 0, 1)).to(DEVICE) / 255.).unsqueeze(0)
            I2 = (torch.tensor(I2.transpose(2, 0, 1)).to(DEVICE) / 255.).unsqueeze(0)
            gt = torch.tensor(I1.transpose(2, 0, 1)).to(DEVICE).unsqueeze(0) / 255.
            mid = model.inference(I0, I2)[0][0]
            # Both metrics are computed on the prediction rounded to 8-bit levels.
            ssim = ssim_matlab(gt, torch.round(mid * 255).unsqueeze(0) / 255.).detach().cpu().numpy()
            mid = np.round((mid * 255).detach().cpu().numpy()).astype('uint8').transpose(1, 2, 0) / 255.
            I1 = I1 / 255.
            mse = ((I1 - mid) * (I1 - mid)).mean()
            # A perfect prediction has zero error; report infinite PSNR rather
            # than letting log10(0) raise.
            psnr = -10 * math.log10(mse) if mse > 0 else float('inf')
            psnr_list.append(psnr)
            ssim_list.append(ssim)
    # Avoid numpy's "mean of empty slice" warning when nothing was evaluated.
    avg_psnr = np.mean(psnr_list) if psnr_list else float('nan')
    avg_ssim = np.mean(ssim_list) if ssim_list else float('nan')
    print("Avg PSNR: {} SSIM: {}".format(avg_psnr, avg_ssim))
    return psnr_list, ssim_list

In [None]:
class VimeoDataset(Dataset):
    """Dataset of frame triplets (``im1.png``, ``im2.png``, ``im3.png``).

    Each sample lives in ``dataset_dir/<triplet name>/`` and the triplet names
    are read from a text file with one name per line.
    """

    def __init__(self, dataset_dir, triplet_list_file, train=True):
        """Read the triplet list and set up the image-to-tensor transform.

        Args:
            dataset_dir: Directory that contains one sub-directory per triplet.
            triplet_list_file: Text file listing one triplet name per line.
            train: Stored on the instance; not otherwise used by this class.
        """
        self.dataset_dir = dataset_dir
        self.triplet_list_file = triplet_list_file
        self.train = train
        self.triplets = self._load_triplets()
        self.transform = transforms.Compose([
            transforms.ToTensor()
        ])

    def _load_triplets(self):
        """Return the non-empty, stripped lines of ``triplet_list_file``."""
        triplets = []
        with open(self.triplet_list_file, 'r') as f:
            for line in f:
                name = line.strip()
                # Blank lines (e.g. a trailing newline at the end of the list)
                # would otherwise become samples whose image paths do not exist.
                if name:
                    triplets.append(name)
        return triplets

    def __len__(self):
        """Number of triplets listed in the list file."""
        return len(self.triplets)

    def __getitem__(self, idx):
        """Load one triplet.

        Returns:
            dict with ``'imgs'`` (im1 and im3 concatenated along dim 0 after
            ``ToTensor``) and ``'gt'`` (im2, the middle frame).
        """
        triplet_path = self.triplets[idx]
        img1_path = os.path.join(self.dataset_dir, triplet_path, 'im1.png')
        img2_path = os.path.join(self.dataset_dir, triplet_path, 'im2.png')  # Ground truth
        img3_path = os.path.join(self.dataset_dir, triplet_path, 'im3.png')
        # Load images
        img1 = Image.open(img1_path).convert('RGB')
        img2 = Image.open(img2_path).convert('RGB')
        img3 = Image.open(img3_path).convert('RGB')
        # Convert to tensors before concatenating
        img1 = self.transform(img1)
        img2 = self.transform(img2)
        img3 = self.transform(img3)
        # The two outer frames form the model input; the middle frame is the target.
        imgs = torch.cat((img1, img3), dim=0)
        return {'imgs': imgs, 'gt': img2}

In [None]:
# Dataset locations, given relative to the notebook's working directory.
VIMEO_DIR = "vimeo_triplet/sequences"
TRAIN_LIST = "vimeo_triplet/tri_trainlist.txt"
VAL_LIST = "vimeo_triplet/tri_vallist.txt"
TEST_LIST = "vimeo_triplet/tri_testlist.txt"

# Training triplets, served in shuffled batches of 32 by four worker
# processes with pinned host memory.
train_dataset = VimeoDataset(VIMEO_DIR, TRAIN_LIST, train=True)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=4, pin_memory=True)

In [None]:
def flow_distill(train_loader, path_idx, num_epoch=10, lr=1e-6, flow_distill_weight=0.0005, output_path="/content/drive/MyDrive/EE 641/Project/",
                 val_list_path='/content/drive/MyDrive/EE 641/Project/vimeo_triplet/tri_vallist.txt',
                 test_list_path='/content/Project1/vimeo_triplet/tri_testlist.txt'):
    """Fine-tune a RIFE student with flow distillation from the global teacher.

    Checkpoints and pickled loss histories are written to the directory
    ``f"{output_path}{path_idx}flow_distill_res"``. The module-level
    ``teacher_model`` is used as the teacher.

    Args:
        train_loader: Iterable of batches shaped like ``{'imgs': ..., 'gt': ...}``.
        path_idx: Tag prepended to the result directory name.
        num_epoch: Number of passes over ``train_loader``.
        lr: Learning rate forwarded to ``update_flow_distill``.
        flow_distill_weight: Weight forwarded to ``update_flow_distill``.
        output_path: Prefix of the result directory.
        val_list_path: Triplet list evaluated every 100 batches.
        test_list_path: Triplet list evaluated once after training.

    Returns:
        None. Results are printed and written to disk.
    """
    result_dir = f"{output_path}{path_idx}flow_distill_res"
    # makedirs + exist_ok lets the cell be re-run and creates missing parents.
    os.makedirs(result_dir, exist_ok=True)
    student_model = RIFE_Model()
    student_model.load_model('ckpt')
    student_model.flownet.to(DEVICE)
    losses_g, losses_flow = [], []
    for epoch in range(num_epoch):
        student_model.train()
        teacher_model.eval()
        total_loss_g = 0
        total_loss_flow = 0
        loss_g, loss_flow = [], []
        for i, data in enumerate(train_loader):
            imgs = data['imgs'].to(DEVICE, non_blocking=True)
            gt = data['gt'].to(DEVICE, non_blocking=True)
            _, loss_dict = student_model.update_flow_distill(
                imgs,
                gt,
                learning_rate=lr,
                training=True,
                teacher_model=teacher_model,
                flow_distill_weight=flow_distill_weight
            )
            # Values recorded for logging only; presumably they mirror the
            # objective optimised inside update_flow_distill (confirm there).
            batch_loss_g = (loss_dict['loss_l1'] + loss_dict['loss_tea'] + loss_dict['loss_distill'] * 0.01).item()
            batch_loss_flow = loss_dict['loss_flow_distill'].item()
            total_loss_g += batch_loss_g
            total_loss_flow += batch_loss_flow
            loss_g.append(batch_loss_g)
            loss_flow.append(batch_loss_flow)
            # Validate every 100 batches.
            if (i + 1) % 100 == 0:
                evaluate(student_model, list_path=val_list_path)
                # evaluate() leaves the model in eval mode; switch back so the
                # remaining batches of this epoch are trained in train mode.
                student_model.train()
            # Log progress and save a checkpoint every 20 batches (starting with the first).
            if i % 20 == 0:
                print(f"Epoch {epoch+1} Batch {i+1}/{len(train_loader)}, "
                    f"RIFE Loss: {batch_loss_g:.4f}, "
                    f"Flow Distill Loss: {batch_loss_flow:.4f}")
                torch.save(student_model.flownet.state_dict(), f"{result_dir}/epoch_{epoch+1}_batch_{i+1}.pkl")
        # Epoch summary; guard against an empty loader instead of dividing by zero.
        num_batches = len(loss_g)
        avg_loss_g = total_loss_g / num_batches if num_batches else float('nan')
        avg_loss_flow = total_loss_flow / num_batches if num_batches else float('nan')
        print(f"Epoch {epoch+1}/{num_epoch} finished. Avg RIFE Loss: {avg_loss_g:.4f}, Avg Flow Distill Loss: {avg_loss_flow:.4f}")
        losses_g.append(loss_g)
        losses_flow.append(loss_flow)
        # Rewrite the full loss history after every epoch so an interrupted
        # run still leaves usable curves on disk.
        with open(f'{result_dir}/losses_g.pkl', 'wb') as f:
            pickle.dump(losses_g, f)
        with open(f'{result_dir}/losses_flow.pkl', 'wb') as f:
            pickle.dump(losses_flow, f)
    # Save final checkpoint
    torch.save(student_model.flownet.state_dict(), f"{result_dir}/final.pkl")
    print(f"Final checkpoint saved for Flow Distillation Weight: {flow_distill_weight}")
    evaluate(student_model, list_path=test_list_path)

In [None]:
# Distillation run tagged '0005' for 20 epochs; every other setting keeps its
# default (including flow_distill_weight=0.0005).
flow_distill(train_loader, path_idx='0005', num_epoch=20)