# Init

In [None]:
from google.colab import drive

# Mount Google Drive: the dataset archive is read from ./drive/MyDrive and the
# best checkpoint is copied back there at the end of the notebook.
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
import datetime
import argparse
import os
import time

import torch
import torch.nn.functional as F
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim
import torch.utils.data
import torchvision.transforms as transforms

import flow_transforms
from torch.utils.tensorboard import SummaryWriter

# Распаковка архива с датасетом

In [None]:
import zipfile

# Unpack the FlyingChairs archive from Drive into the current working directory.
with zipfile.ZipFile('./drive/MyDrive/dataset_chairs.zip') as archive:
    archive.extractall()

# Константы

In [None]:
# Value by which ground-truth flow is divided (see target_transform); EPE values
# are multiplied back by it. The original value is 20, but 1 reportedly gives good
# results when batchNorm is enabled.
DIV_FLOW = 20

# -1 means "no validation result yet"; the training loop replaces it with the first EPE.
best_EPE = -1

start_epoch = 0
# Global optimizer-step counter used as the TensorBoard step for 'train_loss'
# (incremented once per batch in train()).
n_iter = start_epoch
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

TRAIN_RATIO = 0.8
BATCH_SIZE = 8
NUM_WORKERS = 2
EPOCH_SIZE = 1000
EPOCHS = 30
# The previous value `10e-4` is 1e-3 (10 * 10^-4), ten times the FlowNet reference
# learning rate of 1e-4; with batchNorm=False the loss diverged at that rate.
LR = 1e-4
MOMENTUM = 0.9  # Adam beta1
BETA = 0.999    # Adam beta2
WEIGHT_DECAY = 4e-4
BIAS_DECAY = 0
MULTISCALE_WEIGHTS = [0.005,0.01,0.02,0.08,0.32]
PRINT_FREQ = 10
# Epochs at which MultiStepLR multiplies the learning rate by gamma (0.5).
MILESTONES = [10, 20]


# Directory for TensorBoard logs and checkpoints.
save_path = 'model_'

# Логирование прогресса обучения в TensorBoard

In [None]:
train_writer = SummaryWriter(os.path.join(save_path, 'train'))
test_writer = SummaryWriter(os.path.join(save_path, 'test'))
# One writer per logged validation sample: validate() logs the first sample of
# batch k to output_writers[k].
output_writers = [SummaryWriter(os.path.join(save_path, 'test', str(idx)))
                  for idx in range(3)]


# Преобразования для загрузки изображений датасета

In [None]:
 # Data loading code
# Images: to tensor, divide by 255 (presumably uint8-range input), then subtract
# the per-channel mean that validate() adds back for display.
input_transform = transforms.Compose([
    flow_transforms.ArrayToTensor(),
    transforms.Normalize(mean=[0, 0, 0], std=[255, 255, 255]),
    transforms.Normalize(mean=[0.45, 0.432, 0.411], std=[1, 1, 1]),
])
# Flow targets: to tensor, then divide both components by DIV_FLOW.
target_transform = transforms.Compose([
    flow_transforms.ArrayToTensor(),
    transforms.Normalize(mean=[0, 0], std=[DIV_FLOW, DIV_FLOW]),
])

# Создаём пайплайн [преобразований изображений](https://github.com/DimaKurd/FlowNetPytorch/blob/master/flow_transforms.py) для обучения

In [None]:
# Augmentations passed as ListDataset's co_transform (training split only);
# presumably applied jointly to the image pair and its flow — see flow_transforms.
co_transform = flow_transforms.Compose([
    flow_transforms.RandomTranslate(10),
    flow_transforms.RandomRotate(10, 5),
    flow_transforms.RandomCrop((320, 448)),
    flow_transforms.RandomVerticalFlip(),
    flow_transforms.RandomHorizontalFlip(),
])


# Грузим данные, создаём даталоадеры

In [None]:
import os.path
import glob
from listdataset import ListDataset
from util_dataset import split2list
import numpy as np



def make_dataset(dir, split=None):
    '''Will search for triplets that go by the pattern '[name]_img1.ppm  [name]_img2.ppm    [name]_flow.flo'

    Every '*_flow.flo' in `dir` whose two image companions both exist becomes one
    sample ``[[img1_name, img2_name], flow_name]`` (names relative to `dir`).
    Samples are visited in sorted order and split by ``split2list``
    (default ratio 0.97).
    '''
    def _both_frames_exist(first, second):
        return all(os.path.isfile(os.path.join(dir, name)) for name in (first, second))

    samples = []
    for flow_path in sorted(glob.glob(os.path.join(dir,'*_flow.flo'))):
        flow_name = os.path.basename(flow_path)
        stem = flow_name[:-9]  # strip the 9-character '_flow.flo' suffix
        first, second = stem+'_img1.ppm', stem+'_img2.ppm'
        if _both_frames_exist(first, second):
            samples.append([[first, second], flow_name])

    return split2list(samples, split, default_split=0.97)


def flying_chairs(root, transform=None, target_transform=None,
                  co_transform=None, split=None):
    """Build the (train, test) ListDatasets for the FlyingChairs files under `root`.

    Both splits share `transform` and `target_transform`; only the training split
    receives `co_transform`.
    """
    train_samples, test_samples = make_dataset(root, split)
    return (
        ListDataset(root, train_samples, transform, target_transform, co_transform),
        ListDataset(root, test_samples, transform, target_transform),
    )

In [None]:
# Train/test split of the FlyingChairs triplets found in ./dataset.
train_set, test_set = flying_chairs('dataset',
                                    transform=input_transform,
                                    target_transform=target_transform,
                                    co_transform=co_transform,
                                    split=TRAIN_RATIO)

In [None]:
print('Train len: {}. Test len {}'.format(len(train_set), len(test_set)))

Train len: 1228. Test len 267


In [None]:
# Both loaders share batching/worker settings; only training data is shuffled.
loader_kwargs = dict(batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, pin_memory=True)
train_loader = torch.utils.data.DataLoader(train_set, shuffle=True, **loader_kwargs)
val_loader = torch.utils.data.DataLoader(test_set, shuffle=False, **loader_kwargs)


# Объявление модели

In [None]:
import torch
import torch.nn as nn
from torch.nn.init import kaiming_normal_, constant_
from util_model import conv, predict_flow, deconv, crop_like

class FlowNetS(nn.Module):
    """FlowNetS optical-flow network (single-stream encoder/decoder).

    The encoder (conv1 .. conv6_1) processes the stacked frame pair; the decoder
    then predicts flow coarse-to-fine. At each decoder level the encoder skip
    feature, the upsampled decoder feature and the upsampled coarser flow
    (2 channels) are concatenated along dim 1 and fed to the next flow head.

    Args:
        batchNorm: passed as the first argument of every ``conv`` helper call
            (presumably enables BatchNorm inside ``util_model.conv`` — confirm).
    """
    expansion = 1  # NOTE(review): not used inside this class

    def __init__(self,batchNorm=True):
        super(FlowNetS,self).__init__()

        self.batchNorm = batchNorm
        # Encoder: 6 input channels (the caller concatenates two frames along dim 1).
        # NOTE(review): downsampling factors assume ``conv`` applies `stride` like a
        # plain Conv2d — confirm in util_model.
        self.conv1   = conv(self.batchNorm,   6,   64, kernel_size=7, stride=2)
        self.conv2   = conv(self.batchNorm,  64,  128, kernel_size=5, stride=2)
        self.conv3   = conv(self.batchNorm, 128,  256, kernel_size=5, stride=2)
        self.conv3_1 = conv(self.batchNorm, 256,  256)
        self.conv4   = conv(self.batchNorm, 256,  512, stride=2)
        self.conv4_1 = conv(self.batchNorm, 512,  512)
        self.conv5   = conv(self.batchNorm, 512,  512, stride=2)
        self.conv5_1 = conv(self.batchNorm, 512,  512)
        self.conv6   = conv(self.batchNorm, 512, 1024, stride=2)
        self.conv6_1 = conv(self.batchNorm,1024, 1024)

        # Decoder input widths = skip channels + previous deconv channels + 2 flow
        # channels, e.g. deconv4: 1026 = 512 (conv5_1) + 512 (deconv5) + 2 (flow6_up).
        self.deconv5 = deconv(1024,512)
        self.deconv4 = deconv(1026,256)
        self.deconv3 = deconv(770,128)
        self.deconv2 = deconv(386,64)

        # One flow head per level; widths match the concatenations in forward(),
        # e.g. predict_flow2 sees 194 = 128 (conv2) + 64 (deconv2) + 2 (flow3_up).
        self.predict_flow6 = predict_flow(1024)
        self.predict_flow5 = predict_flow(1026)
        self.predict_flow4 = predict_flow(770)
        self.predict_flow3 = predict_flow(386)
        self.predict_flow2 = predict_flow(194)

        # Learnable 2x upsampling of each 2-channel coarse flow
        # (kernel 4, stride 2, padding 1, no bias).
        self.upsampled_flow6_to_5 = nn.ConvTranspose2d(2, 2, 4, 2, 1, bias=False)
        self.upsampled_flow5_to_4 = nn.ConvTranspose2d(2, 2, 4, 2, 1, bias=False)
        self.upsampled_flow4_to_3 = nn.ConvTranspose2d(2, 2, 4, 2, 1, bias=False)
        self.upsampled_flow3_to_2 = nn.ConvTranspose2d(2, 2, 4, 2, 1, bias=False)

        # Init: Kaiming-normal with a=0.1 for (transposed) convs (presumably matching
        # a LeakyReLU(0.1) inside the helpers — TODO confirm) and zero biases;
        # BatchNorm affine params start at weight 1, bias 0.
        for m in self.modules():
            if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
                kaiming_normal_(m.weight, 0.1)
                if m.bias is not None:
                    constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                constant_(m.weight, 1)
                constant_(m.bias, 0)

    def forward(self, x):
        """Predict optical flow for a stacked frame pair.

        Args:
            x: presumably (B, 6, H, W) — two images concatenated along dim 1 by
               train()/validate(). TODO confirm layout.

        Returns:
            In training mode, the tuple (flow2, flow3, flow4, flow5, flow6), finest
            first (the order the multiscale loss receives). In eval mode, only flow2.
            Every flow has 2 channels.
        """
        out_conv2 = self.conv2(self.conv1(x))
        out_conv3 = self.conv3_1(self.conv3(out_conv2))
        out_conv4 = self.conv4_1(self.conv4(out_conv3))
        out_conv5 = self.conv5_1(self.conv5(out_conv4))
        out_conv6 = self.conv6_1(self.conv6(out_conv5))

        # crop_like (util_model) presumably trims an upsampled map to the spatial
        # size of the matching skip feature so torch.cat lines up.
        flow6       = self.predict_flow6(out_conv6)
        flow6_up    = crop_like(self.upsampled_flow6_to_5(flow6), out_conv5)
        out_deconv5 = crop_like(self.deconv5(out_conv6), out_conv5)

        concat5 = torch.cat((out_conv5,out_deconv5,flow6_up),1)
        flow5       = self.predict_flow5(concat5)
        flow5_up    = crop_like(self.upsampled_flow5_to_4(flow5), out_conv4)
        out_deconv4 = crop_like(self.deconv4(concat5), out_conv4)

        concat4 = torch.cat((out_conv4,out_deconv4,flow5_up),1)
        flow4       = self.predict_flow4(concat4)
        flow4_up    = crop_like(self.upsampled_flow4_to_3(flow4), out_conv3)
        out_deconv3 = crop_like(self.deconv3(concat4), out_conv3)

        concat3 = torch.cat((out_conv3,out_deconv3,flow4_up),1)
        flow3       = self.predict_flow3(concat3)
        flow3_up    = crop_like(self.upsampled_flow3_to_2(flow3), out_conv2)
        out_deconv2 = crop_like(self.deconv2(concat3), out_conv2)

        concat2 = torch.cat((out_conv2,out_deconv2,flow3_up),1)
        flow2 = self.predict_flow2(concat2)

        if self.training:
            return flow2,flow3,flow4,flow5,flow6
        else:
            return flow2

    def weight_parameters(self):
        """Return all parameters whose name contains 'weight' (the weight-decayed optimizer group)."""
        return [param for name, param in self.named_parameters() if 'weight' in name]

    def bias_parameters(self):
        """Return all parameters whose name contains 'bias' (the BIAS_DECAY optimizer group)."""
        return [param for name, param in self.named_parameters() if 'bias' in name]




# Инициализируем модель

In [None]:
# FlowNetS without batch normalization, moved to the selected device.
model = FlowNetS(batchNorm=False)
model = model.to(device)

In [None]:
# Biases and weights go into separate optimizer groups with their own weight decay.
# Unwrap first: DataParallel does not expose FlowNetS's helper methods, so without
# this a re-run of the cell (after wrapping) would raise AttributeError.
base_model = getattr(model, 'module', model)
param_groups = [{'params': base_model.bias_parameters(), 'weight_decay': BIAS_DECAY},
                {'params': base_model.weight_parameters(), 'weight_decay': WEIGHT_DECAY}]

# Wrap only once so re-running the cell does not nest DataParallel wrappers.
if device.type == "cuda" and not isinstance(model, torch.nn.DataParallel):
    model = torch.nn.DataParallel(model).cuda()
    cudnn.benchmark = True

In [None]:
# Adam over the bias/weight groups; MOMENTUM and BETA are Adam's beta1/beta2.
optimizer = torch.optim.Adam(param_groups, lr=LR, betas=(MOMENTUM, BETA))

In [None]:
# Halve the learning rate at every MILESTONES epoch (stepped once per epoch).
scheduler = torch.optim.lr_scheduler.MultiStepLR(
    optimizer,
    milestones=MILESTONES,
    gamma=0.5,
)

# Обучение модели

In [None]:
from util import flow2rgb, AverageMeter, save_checkpoint
from multiscaleloss import multiscaleEPE, realEPE


def train(train_loader, model, optimizer, epoch, train_writer):
    """Train `model` for one epoch and return ``(mean loss, mean EPE)``.

    At most EPOCH_SIZE batches are processed (0 means the whole loader). The loss
    is the multiscale EPE over all flow predictions; the reported EPE is that of
    the finest prediction, multiplied by DIV_FLOW to undo the target scaling.

    Args:
        train_loader: loader yielding ``(frames, target)``; ``frames`` is a list of
            image tensors concatenated along dim 1 before the forward pass.
        model: FlowNetS (optionally DataParallel-wrapped); switched to train mode.
        optimizer: optimizer stepped once per batch.
        epoch: epoch index, used only in the progress printout.
        train_writer: SummaryWriter receiving 'train_loss' at global step n_iter.
    """
    global n_iter
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    flow2_EPEs = AverageMeter()

    # Previously epoch_size was always len(train_loader), so EPOCH_SIZE was ignored
    # and the cap below could never trigger.
    epoch_size = len(train_loader) if EPOCH_SIZE == 0 else min(len(train_loader), EPOCH_SIZE)

    # switch to train mode (FlowNetS then returns all five flow predictions)
    model.train()
    end = time.time()

    for i, (images, target) in enumerate(train_loader):
        # Check before doing any work so exactly `epoch_size` batches are trained on.
        if i >= epoch_size:
            break
        # measure data loading time
        data_time.update(time.time() - end)
        target = target.to(device)
        images = torch.cat(images,1).to(device)

        output = model(images)

        loss = multiscaleEPE(output, target, weights=MULTISCALE_WEIGHTS, sparse=False)
        # Monitoring metric only: detach so it does not hold on to the autograd graph.
        flow2_EPE = DIV_FLOW * realEPE(output[0].detach(), target, sparse=False)
        # record loss and EPE
        losses.update(loss.item(), target.size(0))
        train_writer.add_scalar('train_loss', loss.item(), n_iter)
        flow2_EPEs.update(flow2_EPE.item(), target.size(0))

        # compute gradient and do optimization step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % PRINT_FREQ == 0:
            print('Epoch: [{0}][{1}/{2}]\t Time {3}\t Data {4}\t Loss {5}\t EPE {6}'
                  .format(epoch, i, epoch_size, batch_time,
                          data_time, losses, flow2_EPEs))
        n_iter += 1

    return losses.avg, flow2_EPEs.avg


def validate(val_loader, model, epoch, output_writers):
    """Evaluate `model` on `val_loader` and return the mean end-point error.

    Args:
        val_loader: loader yielding ``(frames, target)``; ``frames`` is a list of
            image tensors concatenated along dim 1 before the forward pass.
        model: FlowNetS (optionally DataParallel-wrapped); switched to eval mode here.
        epoch: TensorBoard step for the predicted-flow images. Ground truth and
            inputs are logged only when ``epoch == start_epoch``.
        output_writers: SummaryWriters; writer ``i`` logs the first sample of batch ``i``.

    Returns:
        Mean EPE over the loader, multiplied by DIV_FLOW to undo the target scaling.
    """
    batch_time = AverageMeter()
    flow2_EPEs = AverageMeter()

    # switch to evaluate mode (FlowNetS then returns only the finest flow)
    model.eval()

    end = time.time()
    # No backward pass happens here: disable autograd for the whole loop so memory
    # is not wasted even when the caller has no outer ``torch.no_grad()``.
    with torch.no_grad():
        for i, (images, target) in enumerate(val_loader):
            target = target.to(device)
            images = torch.cat(images, 1).to(device)

            output = model(images)
            flow2_EPE = DIV_FLOW * realEPE(output, target, sparse=False)
            flow2_EPEs.update(flow2_EPE.item(), target.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i < len(output_writers):  # log first output of first batches
                if epoch == start_epoch:
                    # Add back the mean that input_transform subtracted, for display.
                    mean_values = torch.tensor([0.45, 0.432, 0.411], dtype=images.dtype).view(3, 1, 1)
                    output_writers[i].add_image('GroundTruth', flow2rgb(DIV_FLOW * target[0], max_value=10), 0)
                    output_writers[i].add_image('Inputs', (images[0, :3].cpu() + mean_values).clamp(0, 1), 0)
                    output_writers[i].add_image('Inputs', (images[0, 3:].cpu() + mean_values).clamp(0, 1), 1)
                output_writers[i].add_image('FlowNet Outputs', flow2rgb(DIV_FLOW * output[0], max_value=10), epoch)

            if i % PRINT_FREQ == 0:
                print('Test: [{0}/{1}]\t Time {2}\t EPE {3}'
                      .format(i, len(val_loader), batch_time, flow2_EPEs))

    print(' * EPE {:.3f}'.format(flow2_EPEs.avg))

    return flow2_EPEs.avg



# Epochs are counted from start_epoch; n_iter is the per-batch step counter and
# must not be used here (after any training it is far larger than the epoch index).
for epoch in range(start_epoch, EPOCHS):
    # train for one epoch
    train_loss, train_EPE = train(train_loader, model, optimizer, epoch, train_writer)
    train_writer.add_scalar('mean EPE', train_EPE, epoch)
    # Since PyTorch 1.1 the scheduler must step after optimizer.step(); stepping
    # before training skipped the initial learning-rate value.
    scheduler.step()

    # evaluate on validation set
    with torch.no_grad():
        EPE = validate(val_loader, model, epoch, output_writers)
    test_writer.add_scalar('mean EPE', EPE, epoch)

    # best_EPE < 0 means nothing has been evaluated yet, so the first result is the
    # best so far. The old `EPE < best_EPE` check was always False on that first
    # epoch, so model_best could never be written for it.
    is_best = best_EPE < 0 or EPE < best_EPE
    if is_best:
        best_EPE = EPE
    save_checkpoint({
        'epoch': epoch + 1,
        'arch': 'FlowNetS',
        # DataParallel (CUDA only) stores the net in .module; on CPU there is no
        # wrapper, so fall back to the model itself.
        'state_dict': getattr(model, 'module', model).state_dict(),
        'best_EPE': best_EPE,
        'div_flow': DIV_FLOW
    }, is_best, save_path)




Epoch: [0][0/154]	 Time 20.022 (20.022)	 Data 3.073 (3.073)	 Loss 109.026 (109.026)	 EPE 22.229 (22.229)
Epoch: [0][10/154]	 Time 0.241 (2.943)	 Data 0.000 (1.258)	 Loss 345.219 (615.874)	 EPE 16.695 (34.599)
Epoch: [0][20/154]	 Time 0.259 (2.374)	 Data 0.000 (1.410)	 Loss 296.462 (411.514)	 EPE 17.475 (27.403)
Epoch: [0][30/154]	 Time 0.238 (2.193)	 Data 0.000 (1.480)	 Loss 153.376 (379.840)	 EPE 17.500 (25.467)
Epoch: [0][40/154]	 Time 0.211 (2.066)	 Data 0.005 (1.485)	 Loss 1660.138 (657.372)	 EPE 26.589 (28.538)
Epoch: [0][50/154]	 Time 0.254 (2.003)	 Data 0.000 (1.500)	 Loss 5822.542 (1030.561)	 EPE 140.532 (32.478)
Epoch: [0][60/154]	 Time 0.236 (1.988)	 Data 0.000 (1.537)	 Loss 15088.424 (5669.139)	 EPE 302.290 (120.543)
Epoch: [0][70/154]	 Time 0.243 (2.003)	 Data 0.000 (1.589)	 Loss 1665893.250 (37391.799)	 EPE 38892.789 (765.628)
Epoch: [0][80/154]	 Time 0.240 (1.973)	 Data 0.004 (1.588)	 Loss 96297.312 (60697.474)	 EPE 499.145 (926.225)
Epoch: [0][90/154]	 Time 0.241 (1.942)

# Validate the model

In [None]:
# Re-evaluate the current model. Store the result separately: overwriting best_EPE
# would corrupt the "best so far" baseline the training loop uses for is_best.
with torch.no_grad():
    val_EPE = validate(val_loader, model, 0, output_writers)
val_EPE

Test: [0/34]	 Time 0.681 (0.681)	 EPE 16.380 (16.380)
Test: [10/34]	 Time 0.216 (0.320)	 EPE 13.809 (12.622)
Test: [20/34]	 Time 0.135 (0.262)	 EPE 16.868 (12.818)
Test: [30/34]	 Time 0.178 (0.248)	 EPE 14.268 (12.078)
 * EPE 12.096


# Save the best model checkpoint to Google Drive

In [None]:
import shutil

# Copy the best checkpoint to Drive so it outlives the Colab session. Unlike
# `!cp`, which only prints an error, shutil raises FileNotFoundError when no
# best checkpoint was written.
shutil.copyfile(os.path.join(save_path, 'model_best.pth.tar'),
                './drive/MyDrive/flownets.pth.tar')