In [1]:
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import cv2
import os 
# Training configuration shared by the cells below.
EPOCHS = 10  # the training loop iterates over range(starting_epoch, EPOCHS)
ckpt_dir = './ckpts/'  # checkpoints are read from and written to this directory
batch_size, height, width = 4,128,128  # height/width are not referenced by the cells visible here
in_flows = 20  # consecutive flow fields per sample; the model is built with in_flows*2 channels
# Prefer the GPU, but fall back to the CPU so the notebook still runs on a
# machine without CUDA instead of failing at the first `.to(device)`.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
starting_epoch = 0  # overwritten when a checkpoint is resumed

In [2]:
from model import MPI_Net
# Build the network and its optimizer, then resume from the newest checkpoint if one exists.
model = MPI_Net(input_channels= in_flows*2,num_outputs=in_flows*2, ngf=32).train().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, betas = [0.9,0.99])

# Create the checkpoint directory up front: listing a missing directory raises,
# and the training loop later saves into the same location.
os.makedirs(ckpt_dir, exist_ok=True)

# Map every file named `<prefix>_<epoch>.<ext>` to its integer epoch. Files that
# do not follow the pattern are ignored instead of raising IndexError, and the
# epochs are compared numerically so that e.g. `mpi_10` sorts after `mpi_9`.
ckpt_epochs = {}
for ckpt_name in os.listdir(ckpt_dir):
    name_parts = ckpt_name.split('.')[0].split('_')
    if len(name_parts) > 1 and name_parts[1].isdigit():
        ckpt_epochs[ckpt_name] = int(name_parts[1])

ckpts = sorted(ckpt_epochs, key=ckpt_epochs.get)
if ckpts:
    latest = ckpts[-1]
    # map_location lets a checkpoint written on one device be restored on another.
    state_dict = torch.load(os.path.join(ckpt_dir,latest), map_location=device)
    model.load_state_dict(state_dict['model'])
    # The checkpoint stores the epoch it was written in; training resumes at the next one.
    starting_epoch = state_dict['epoch'] + 1
    optimizer.load_state_dict(state_dict['optimizer'])
    print('loaded weights from previous session')
    print(f'starting from epoch {starting_epoch}')
    print('learning_rate',optimizer.param_groups[0]['lr'])

# Count only the parameters that will receive gradients.
total_trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Total Trainable Parameters: {total_trainable_params}")

Total Trainable Parameters: 16929120


In [3]:
from torch.utils.data import Dataset, DataLoader
import os
import random
import numpy as np
class DeepStab_Synthetic(Dataset):
    """Dataset of fixed-length windows of flow fields listed in a text file.

    Every non-blank line of `txt_path` is `<npy_path>,<start_index>`. A sample
    is built from the `in_flows` (module-level setting) consecutive entries of
    that array beginning at `start_index`.
    """

    def __init__(self, txt_path = './trainlist.txt'):
        """Read the sample list; blank or whitespace-only lines are skipped."""
        with open(txt_path, 'r') as f:
            # A blank line would otherwise count as a sample and then fail in
            # __getitem__ when its (missing) start index is parsed.
            self.trainlist = [line for line in f.read().splitlines() if line.strip()]

    def __len__(self):
        """Number of samples, i.e. usable lines in the list file."""
        return len(self.trainlist)

    def __getitem__(self,idx):
        """Return sample `idx` as a float tensor with the channel axis first.

        The selected entries are concatenated along their last axis and that
        axis is moved to the front.
        # assumes each entry is (H, W, 2), giving (2 * in_flows, H, W) — TODO confirm
        """
        fields = self.trainlist[idx].split(',')
        path = fields[0]
        start = int(fields[1])
        # Memory-map the file so only the requested window is read from disk.
        flow_file = np.load(path, mmap_mode='r')
        window = flow_file[start: start + in_flows,...]
        # Iterating `window` yields its entries along axis 0; join them on the last axis.
        stacked = np.concatenate(window,axis = -1)
        return torch.from_numpy(stacked).permute(2,0,1).float()


# Instantiate the dataset with its default list file and wrap it in a
# shuffling loader that yields `batch_size` samples per step.
train_ds = DeepStab_Synthetic()
train_loader = DataLoader(
    dataset=train_ds,
    batch_size=batch_size,
    shuffle=True,
)

In [4]:
from torch.utils.tensorboard import SummaryWriter
# TensorBoard event files for this notebook are written under "runs/".
writer = SummaryWriter(log_dir='runs/')

In [5]:
#losses
def motion_loss(net_out, net_in):
    '''
    Temporal smoothness of the warped trajectories plus a penalty on the warp itself.

    input:
        net_out/W: network output warp fields, 4-D with channels interleaved as
                   dx0, dy0, dx1, dy1, ... along dim 1
        net_in/F: network input flow fields, same layout as net_out
    returns:
        scalar tensor `indicator * smoothness + similarity`
    '''
    epsilon = 1e-8
    # De-interleave the channels into (batch, flows, h, w, 2) with (dx, dy) last.
    warp = torch.stack([net_out[:, ::2], net_out[:, 1::2]], dim=-1)
    flow_dx = net_in[:, ::2]
    flow_dy = net_in[:, 1::2]
    flow = torch.stack([flow_dx, flow_dy], dim=-1)

    # Accumulating the flows over the flow axis gives a per-pixel trajectory.
    pixel_trajectories = torch.cumsum(flow, dim=1)
    smooth_trajectories = pixel_trajectories + warp

    # Consecutive positions along each warped trajectory.
    p_hat = smooth_trajectories[:, :-1, ...]
    q_hat = smooth_trajectories[:, 1:, ...]
    # NOTE(review): epsilon is added before squaring, so it does not keep the
    # sqrt argument away from zero; kept as-is to leave the loss values unchanged.
    step_length = torch.sqrt(torch.sum(torch.pow(p_hat - q_hat + epsilon, 2), dim=-1))
    smoothness = torch.mean(torch.sum(step_length, dim=(1, 2, 3)))
    similarity = torch.mean(torch.sqrt((pixel_trajectories - smooth_trajectories + epsilon)**2))

    # Weight the smoothness term by the mean input motion, computed over the whole batch.
    magnitude = torch.sqrt(flow_dx.mean()**2 + flow_dy.mean()**2)
    # torch.clamp keeps the value on-device; Python's max() would compare the
    # tensor with an int, forcing a host sync and returning either type.
    indicator = torch.clamp(-1.93 * magnitude + 0.95, min=0)
    return indicator * smoothness + similarity


def spatial_loss(net_out):
    '''
    Frequency-domain penalty on the predicted warp fields.

    input:
        net_out/W: network output warp fields, 4-D (b, c, h, w) with channels
                   interleaved as dx0, dy0, dx1, dy1, ... along dim 1
    returns:
        scalar tensor: mean absolute value of the mask-weighted magnitude spectrum
    '''
    epsilon = 1e-8
    device = net_out.device
    b, c, h, w = net_out.shape
    dx = net_out[:, ::2, :, :]
    dy = net_out[:, 1::2, :, :]
    # (b, c // 2, h, w, 2): the (dx, dy) component is the LAST axis.
    W = torch.stack([dx, dy], dim=-1)

    # fft2 defaults to the last two dims, which here would be (w, component).
    # Transform and shift over the spatial axes (h, w) explicitly instead.
    spatial_dims = (2, 3)
    W_fft = torch.fft.fft2(W, dim=spatial_dims)
    # Shift the zero-frequency component to the center of the spectrum
    W_fft_shifted = torch.fft.fftshift(W_fft, dim=spatial_dims)
    magnitude_spectrum = torch.abs(W_fft_shifted)

    # Inverted Gaussian weighting over the (h, w) grid, built on the input's device.
    # NOTE(review): the Gaussian is centred at index (0, 0) while the shifted
    # spectrum has its zero frequency at (h // 2, w // 2), and the mean
    # subtraction makes some weights negative before the absolute value below.
    # Both are kept as in the original — confirm the intended weighting.
    mu = 0
    sigma = 3
    x = torch.arange(h, device=device)
    y = torch.arange(w, device=device)
    xx, yy = torch.meshgrid(x, y, indexing='ij')
    mask = torch.exp(-0.5 * ((xx - mu) ** 2 + (yy - mu) ** 2) / sigma ** 2)
    inverted_mask = (mask.max() - mask) / (mask.max() + epsilon)
    inverted_mask = inverted_mask - inverted_mask.mean()
    # Broadcast over batch, flow and component axes rather than repeating with a
    # module-level flow count, so any number of flows in net_out is accepted.
    weights = inverted_mask.view(1, 1, h, w, 1)

    loss = torch.mean(torch.sqrt((weights * magnitude_spectrum + epsilon)**2))
    return loss

In [6]:
def show_flow(flow):
    """Render a flow field as an RGB image via an HSV encoding.

    `flow[..., 0]` and `flow[..., 1]` are converted to polar form. Half the
    angle in degrees becomes the hue, saturation is fixed at 255, and the
    magnitude, min-max normalised to 0..255, becomes the value channel.
    Returns the uint8 image produced by the HSV-to-RGB conversion.
    """
    magnitude, angle_deg = cv2.cartToPolar(flow[..., 0], flow[..., 1], angleInDegrees=True)

    hsv = np.zeros(shape=flow.shape[:-1] + (3,), dtype=np.uint8)
    hsv[:, :, 0] = angle_deg / 2  # halved so a full turn fits the uint8 hue channel
    hsv[..., 1] = 255
    hsv[:, :, 2] = cv2.normalize(magnitude, None, 0, 255, cv2.NORM_MINMAX)

    return cv2.cvtColor(hsv, cv2.COLOR_HSV2RGB)

In [None]:
# Train the model, previewing one predicted / input flow pair in an OpenCV window each step.
dataset_len = len(train_ds.trainlist)
# First channel of the (dx, dy) pair shown in the preview. 40 is the original
# choice; it is clamped to the last valid pair because with in_flows = 20 the
# tensors have only 40 channels and the slice 40:42 would be empty.
vis_channel = min(40, 2 * in_flows - 2)
cv2.namedWindow('window',cv2.WINDOW_NORMAL)
try:
    for epoch in range(starting_epoch,EPOCHS):
        # After the first epoch the learning rate is lowered to 1e-5.
        if epoch > 0 :
            for param_group in optimizer.param_groups:
                param_group['lr'] = 1e-5
        running_loss = 0
        for idx,f in enumerate(train_loader):
            loss = 0
            # Use the configured device rather than a hard-coded 'cuda'.
            f = f.to(device)
            w = model(f)
            lm = motion_loss(w,f)
            loss += lm
            # The frequency-domain term is only applied during the first epoch.
            if epoch < 1:
                lf =  10 * spatial_loss(w)
                loss += lf

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            running_loss += loss.item()

            # Preview: predicted warp on the left, input flow on the right.
            pca =  f[0,vis_channel:vis_channel + 2,:,:].cpu().detach().permute(1,2,0).numpy()
            pred = w[0,vis_channel:vis_channel + 2,:,:].cpu().detach().permute(1,2,0).numpy()
            img1 = show_flow(pred)
            img2 = show_flow(pca)
            conc = cv2.hconcat([img1, img2])

            cv2.imshow('window',conc)
            # Pressing '9' leaves the current epoch only; the outer loop continues.
            if cv2.waitKey(1) & 0xFF == ord('9'):
                break
            print(f'\repoch: {epoch} iter:{idx} loss: {running_loss / (idx % 1000 + 1) :.4f}',end ="")
            # Every 1000 iterations: log the mean loss and write a checkpoint.
            if idx % 1000 == 999:
                # NOTE(review): the step mixes the sample count (dataset_len) with
                # a batch index (idx); kept so existing TensorBoard logs stay comparable.
                writer.add_scalar('training_loss',
                                    running_loss / 1000,
                                    epoch * dataset_len + idx)
                running_loss = 0.0
                model_path = os.path.join(ckpt_dir,f'mpi_{epoch}.pth')
                torch.save({'model':model.state_dict(),
                            'optimizer' : optimizer.state_dict(),
                            'epoch' : epoch}
                        ,model_path)
finally:
    # Close the preview window even when training is interrupted or raises.
    cv2.destroyAllWindows()