In [None]:
import torch
from torch.nn import functional as F
from models.Flowestimator import FlowEstimator
import pytorch_lightning as pl
from utils.warper import warper
from dataloader.sintelloader import SintelLoader
from utils.photometricloss import photometricloss
from torch.optim.lr_scheduler import ReduceLROnPlateau
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter
from utils.averagemeter import AverageMeter
from torchvision.utils import make_grid
from utils.flow2rgb import flow2rgb
from utils.replicateto3channel import replicatechannel

In [2]:
class FlowTrainer(object):
    def __init__(self):
        super(FlowTrainer, self).__init__()
        # not the best model...
        self.model = FlowEstimator(shape = (256,256), 
                                   use_l2 = True, 
                                   channel_in = 3, 
                                   stride = 1, 
                                   kernel_size = 2, 
                                   use_cst = True)
        self.optimizer = None
        self.lr_scheduler = None
        self.save_dir = None
        
        self.epoch = 1
        
        self.train_loader = SintelLoader(batch_size = 20, 
                                         pin_memory = True, 
                                         num_workers = 8,
                                        nsample=100)
        
        self.val_loader = None
        
        self.test_loader = SintelLoader(sintel_root="/data/keshav/sintel/test/final",
                                        batch_size = 1, 
                                        pin_memory = True, 
                                        num_workers = 8)
        
        self.sample_test = [*SintelLoader(sintel_root="/data/keshav/sintel/test/final", 
                            test = True, nsample=10,visualize=True).load()][0]
        self.sample_train = [*SintelLoader(nsample=10,visualize=True).load()][0]
        
        self.sample_val = None
        
        self.save_model_path = './best/'
        self.load_model_path = None   
        self.best_metrics = {'train_loss':None, 
                             'val_loss':None}
        self.gpu_ids = [0,1,2,3,4,5,6,7]
        
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=0.02)
        self.scheduler = ReduceLROnPlateau(self.optimizer)
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        
        self.photoloss = torch.nn.MSELoss()
        
        self.writer = SummaryWriter()
        self.global_step = 0        
        
    def initialize(self):
        self.model.to(self.device)
        self.model = torch.nn.DataParallel(self.model, device_ids = self.gpu_ids)
        if self.load_model_path:
            #LOAD MODEL WEIGHTS HERE
            pass
        self.initialized = True
        
    
    def savemodel(self, metrics, compare = 'val_loss'):
        #Save model in save_model_path
        if self.best_metrics.get('val_loss') > metrics.get('val_loss'):
            #save only if new metrics are low
            self.best_metrics.update(metrics)
            pass
        else:
            # Load from the best saved
            pass
        
    def train_epoch_end(self,metrics):
        self.model.eval()
        
        with torch.no_grad():
            frame1 = self.sample_train['frame1'].to(self.device)
            frame2 = self.sample_train['frame2'].to(self.device)
            
            flow, occ = self.model(frame1, frame2)
            frame1_ = warper(flow, frame2)
            print(occ.shape)
            occ = replicatechannel(occ)
            print(occ.shape)
            print(self.sample_train['occlusion'].shape)
            sampocc = replicatechannel(self.sample_train['occlusion'].cuda())
            print(sampocc.shape)
            occs = torch.cat([sampocc, occ])
            
            frames = torch.cat([frame1_, frame1, frame2])
            frames = make_grid(frames, nrow=10).unsqueeze(0)
            
            flows = torch.cat([flow2rgb(flow.cpu()).cuda(),self.sample_train['flow'].cuda()])
            flows = make_grid(flows, nrow = 10).unsqueeze(0)
            
            self.writer.add_images('TRAIN/Frames',frames,metrics.get('nb_batch'))
            self.writer.add_images('TRAIN/Flows',flows,metrics.get('nb_batch'))
            self.writer.add_images('TRAIN/Occlusions',occs,metrics.get('nb_batch'))
        
        return self.val(metrics)
    
    def val_end(self,metrics):
        return metrics
        

            
    
    def test_end(self,metrics):
        with torch.no_grad():
            frame1 = self.sample_test['frame1'].to(self.device)
            frame2 = self.sample_test['frame2'].to(self.device)
            
            flow, occ = self.model(frame1, frame2)
            frame1_ = warper(flow, frame2)
            occ = replicatechannel(occ)
            
            frames = torch.cat([frame1_, frame1, frame2, flow2rgb(flow.cpu()).cuda(),occ])
            frames = make_grid(frames, nrow=10).unsqueeze(0)
            
            self.writer.add_images('TEST/Frames',frames,metrics.get('nb_batch'))
        return metrics
            
    
    def train(self, nb_epoch):
        trainstream = tqdm(self.train_loader.load())
        self.avg_loss = AverageMeter()
        self.model.train()
        for i,data in enumerate(trainstream):
            self.global_step += 1
            trainstream.set_description('TRAINING')
            
            #GET X and Frame 2
            #wdt = data['displacement'].to(self.device)
            frame2 = data['frame2'].to(self.device)
            frame1 = data['frame1'].to(self.device)
            
            self.optimizer.zero_grad()
            
            #forward
            with torch.set_grad_enabled(True):
                flow, occ = self.model(frame1, frame2)
                frame1_ = warper(flow, frame2)
                loss = photometricloss(frame1, frame1_, occ)
                self.avg_loss.update(loss.item(),i+1)
                loss.backward()
                self.optimizer.step()
                
                self.writer.add_scalar('Loss/train', 
                                       self.avg_loss.avg, self.global_step)
                
                trainstream.set_postfix({'epoch':nb_epoch, 
                                         'loss':self.avg_loss.avg})
        self.scheduler.step(loss)
        trainstream.close()            
        return self.train_epoch_end({'TRloss':self.avg_loss.avg,'epoch':nb_epoch,})

    
    def val(self,metrics):
        if self.val_loader is None:return self.test(metrics)
        #DO VAL STUFF HERE
        valstream = tqdm(self.val_loader.load())
        for data in valstream:
            pass
        return self.val_end(metrics)
    
    def test(self, metrics = {}):
        teststream = tqdm(self.test_loader.load())
        self.avg_loss = AverageMeter()
        with torch.no_grad():
            for i,data in enumerate(teststream):
                teststream.set_description('TESTING')
                frame2 = data['frame2'].to(self.device)
                frame1 = data['frame1'].to(self.device)
                flow, occ = self.model(frame1, frame2)
                frame1_ = warper(flow, frame2)
                loss = photometricloss(frame1, frame1_, occ)
                self.avg_loss.update(loss.item(),i+1)
                metrics.update({'TSloss':self.avg_loss.avg})
                teststream.set_postfix(metrics)
        
        self.writer.add_scalar('Loss/test', self.avg_loss.avg)
        teststream.close()
        
        return self.test_end(metrics)    
    
    def loggings(self,**metrics):
        pass
    
    def run(self):
        self.initialize()
        for i in range(self.epoch):
            metrics = self.train(i)
        self.test(metrics)
        self.writer.close()

In [3]:
ft = FlowTrainer()

In [None]:
torch.nn.Ca

In [4]:
ft.run()

TRAINING: 100%|██████████| 1/1 [00:37<00:00, 37.35s/it, epoch=0, loss=nan]


torch.Size([10, 1, 256, 256])
torch.Size([10, 3, 256, 256])
torch.Size([10, 1, 256, 256])
torch.Size([10, 3, 256, 256])


TESTING: 100%|██████████| 552/552 [00:06<00:00, 89.49it/s, TRloss=nan, epoch=0, TSloss=nan] 
TESTING: 100%|██████████| 552/552 [00:06<00:00, 81.21it/s, TRloss=nan, epoch=0, TSloss=nan] 


In [None]:
import re
from PIL import Image
from torchvision.transforms import ToPILImage, ToTensor
import torch
tostr = lambda x:int(''.join(re.findall(r'\d+',x.as_posix())))
maptotensor = lambda x:ToTensor()(Image.open(x).resize((256,256))).unsqueeze(0)

In [None]:
from pathlib import Path
pth = [*Path("/data/keshav/sintel/training/occlusions/alley_1/").glob('./*.png')]
pth = [*map(lambda x:x.as_posix(),sorted(pth, key = tostr))]
pth = pth[:5]

In [None]:
occten = torch.cat([*map(maptotensor, pth)],0)

In [None]:
ToPILImage()(occten[0])

In [None]:
occten.shape

In [None]:
occbig = occten.view(1,-1).repeat(3,1).view(3,5,256,256).permute(1,0,2,3)

In [None]:
occbig.shape

In [None]:
ToPILImage()(occbig[1])

In [None]:
import torch.nn as nn
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()

        self.up = nn.Upsample(scale_factor=2, mode='nearest')

        self.conv11 = nn.Conv2d(1, 128, kernel_size=3, padding=1)
        self.conv12 = nn.Conv2d(128, 256, kernel_size=3, padding=1)
        self.conv13 = nn.Conv2d(256, 2, kernel_size=3, padding=1)  



    def forward(self, x):
        in_size = x.size(0)

        x = F.relu(self.conv11(x))
        x = F.relu(self.conv12(x))
        x = F.relu(self.conv13(x))

        x = F.softmax(x, 1) #this line is changed

        return x

net = Net()
inputs = 0.5 - torch.rand(1,1,4,4)
print(inputs)
out = net(inputs)
print (out)
out.sum(dim=1)