# Seq2bg

## Imports

In [1]:
import torch
import torch.nn as nn
from torch.optim import Adam
from Seq2Seq import Seq2Seq
from torch.utils.data import DataLoader
import io
#import imageio
from ipywidgets import widgets, HBox

In [2]:
import numpy as np
import matplotlib.pyplot as plt
import os
import cv2

# Use GPU if available
#device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

## Load data

In [3]:
input_path = "highway/input"
gt_path = "highway/groundtruth/"

# Llegim totes les imatges d'input i les incorporam a una ED de numpy
imatges_input = os.listdir(input_path)
imatges_input.sort()

seq_len = 4
n_seq = 0
n_seqs = len(imatges_input) // seq_len

# batch_size, num_channels, seq_len, height, width
input_data = np.zeros((n_seqs, 3, seq_len, 240, 320))#, dtype=np.uint8)

for idx, imatge in enumerate(imatges_input):
    aux = cv2.imread(input_path +"/" + imatge)
    index = idx 
    for s in range(seq_len):
        
        input_data[n_seq,:,s,:,:] = np.moveaxis(aux, -1, 0)
        
    n_seq = (n_seq + 1) % seq_len
    
# Llegim totes les imatges de GT i les incorporam a una ED de numpy
input_gt = np.zeros((n_seqs, 1, 240, 320))
imatges_gt = os.listdir(gt_path)
imatges_gt.sort()

n_seq = 0
for idx in range(0, n_seqs, seq_len):
    aux = cv2.imread(gt_path +"/" + imatges_gt[idx], cv2.IMREAD_GRAYSCALE)
    input_gt[n_seq,:,:,:] = aux
    n_seq += 1

In [4]:
dataset = list(zip(input_data,input_gt))

In [5]:
def collate(batch):
    #(batch_size, num_channels, seq_len, height, width)
    seq_len = 4
    
    in_image = np.zeros((len(batch), 3, seq_len, 240, 320)) #, dtype=np.uint8)    
    for i in range(len(batch)):
        in_image[i,:,:,:] = batch[i][0]

    return torch.tensor(in_image,dtype=torch.float32), torch.tensor(batch[-1][1],dtype=torch.float32)


# Training Data Loader
train_loader = DataLoader(dataset, shuffle=False, batch_size=4, collate_fn=collate)

# Validation Data Loader
#val_loader = DataLoader(val_data, shuffle=True, 
#                        batch_size=16, collate_fn=collate)
iterador = iter(train_loader)
data, label = next(iterador)


## Entrenament


In [6]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# The input video frames are grayscale, thus single channel

model = Seq2Seq(num_channels=3, num_kernels=64, 
kernel_size=(3, 3), padding=(1, 1), activation="relu", 
frame_size=(240, 320), num_layers=3).to(device)

optim = Adam(model.parameters(), lr=1e-4)

# Binary Cross Entropy, target pixel values either 0 or 1
criterion = nn.BCELoss(reduction='sum')

In [7]:
num_epochs = 20

for epoch in range(1, num_epochs+1):
    
    train_loss = 0                                                 
    model.train()                                                  
    for batch_num, (input_img, target) in enumerate(train_loader, 1):  
        print(input_img.dtype)
        output = model(input_img)                                     
        loss = criterion(output.flatten(), target.flatten()) 
        loss.backward()                                            
        optim.step()                                               
        optim.zero_grad()                                           
        train_loss += loss.item()    
        
    train_loss /= len(train_loader.dataset)                       

    #val_loss = 0                                                 
    #model.eval()                                                   
    #with torch.no_grad():                                          
    #    for input, target in val_loader:                          
    #        output = model(input)                                   
    #        loss = criterion(output.flatten(), target.flatten())   
    #        val_loss += loss.item()                                
    #val_loss /= len(val_loader.dataset)                            

    print("Epoch:{} Training Loss:{:.2f}")# Validation Loss:{:.2f}\n".format(
    #   epoch, train_loss, val_loss))

torch.float32


KeyboardInterrupt: 