# Train Model

In [1]:
import numpy as np
import os
import yaml
import matplotlib.pyplot as plt
import torch
from torchvision import transforms
from torch.autograd import Variable
from networks.data_utils import get_imdb_data
from torch.utils.data import DataLoader, Dataset
import random
import torch.nn.functional as F
from networks.relay_net import ReLayNet
from networks.solver import Solver



In [2]:

class RandomTransforms(object):
    """Randomly augment an ``(image, target, weight)`` triple.

    ``image`` and ``target`` are indexed as 3-D arrays ``(rows, cols, channels)``
    and ``weight`` as a 2-D array ``(rows, cols)``.

    On each call the augmentations below are tried in order; the first one whose
    random draw succeeds is applied and its result returned immediately, so at
    most one augmentation is applied per call:

    1. flip along axis 1 (draw against ``prob1``),
    2. centre crop of ``h_size`` x ``w_size`` refilled to ``height`` x ``width``
       via ``np.resize`` (draw against ``prob2``),
    3. centre strip moved to the top/left edge of an all-zero canvas of the
       input shape (draw against ``prob3``); a further draw against
       ``translate_prob`` picks the column strip, otherwise the row strip.

    If no draw succeeds the inputs are returned unchanged.

    Args:
        height: Number of rows of the arrays produced by the crop augmentation.
        width: Number of columns of the arrays produced by the crop augmentation.
        layers: Number of channels of the target produced by the crop augmentation.
        prob1: Threshold for the flip draw.
        prob2: Threshold for the centre-crop draw.
        prob3: Threshold for the translate draw.
        h_size: Number of rows kept by the crop / row-strip augmentations.
        w_size: Number of columns kept by the crop / column-strip augmentations.
    """

    def __init__(self, height, width, layers, prob1=0.5, prob2=0.5, prob3 = 0.5, h_size= 490, w_size= 60):
        self.height = height
        self.width = width
        self.layers = layers
        self.prob = prob1
        self.prob2 = prob2
        self.prob3 = prob3
        # Chance of picking the column strip over the row strip when translating.
        self.translate_prob = 0.5
        self.h_size = h_size
        self.w_size = w_size

    def __call__(self, image, target, weight):
        """Apply at most one random augmentation and return the new triple."""
        if random.random() < self.prob:
            return self._flip(image, target, weight)

        if random.random() < self.prob2:
            return self._center_crop(image, target, weight)

        if random.random() < self.prob3:
            if random.random() < self.translate_prob:
                return self._shift_center_columns(image, target, weight)
            return self._shift_center_rows(image, target, weight)

        return image, target, weight

    @staticmethod
    def _flip(image, target, weight):
        """Reverse all three arrays along axis 1."""
        return np.flip(image, 1), np.flip(target, 1), np.flip(weight, 1)

    def _center_crop(self, image, target, weight):
        """Cut the central ``h_size`` x ``w_size`` window and refill to full size."""
        y, x, _ = image.shape
        startx = x // 2 - self.w_size // 2
        starty = y // 2 - self.h_size // 2
        rows = slice(starty, starty + self.h_size)
        cols = slice(startx, startx + self.w_size)
        # NOTE(review): np.resize fills the larger output by repeating the
        # flattened crop; it does not interpolate. Kept as-is to preserve the
        # existing augmentation -- confirm this is the intended behaviour.
        image = np.resize(image[rows, cols, :], (self.height, self.width, 1))
        target = np.resize(target[rows, cols, :], (self.height, self.width, self.layers))
        weight = np.resize(weight[rows, cols], (self.height, self.width))
        return image, target, weight

    def _shift_center_columns(self, image, target, weight):
        """Move the central ``w_size`` columns to the left edge of a zero canvas."""
        pad_image = np.zeros(image.shape)
        pad_target = np.zeros(target.shape)
        pad_weight = np.zeros(weight.shape)

        _, x, _ = image.shape
        startx = x // 2 - self.w_size // 2
        cols = slice(startx, startx + self.w_size)

        image_strip = image[:, cols, :]
        strip_width = image_strip.shape[1]
        pad_image[:, :strip_width, :] = image_strip
        pad_target[:, :strip_width, :] = target[:, cols, :]
        # The strip must land in pad_weight; writing it back into the cropped
        # weight array left the returned weight map entirely zero.
        pad_weight[:, :strip_width] = weight[:, cols]

        return pad_image, pad_target, pad_weight

    def _shift_center_rows(self, image, target, weight):
        """Move the central ``h_size`` rows to the top edge of a zero canvas."""
        pad_image = np.zeros(image.shape)
        pad_target = np.zeros(target.shape)
        pad_weight = np.zeros(weight.shape)

        y, _, _ = image.shape
        starty = y // 2 - self.h_size // 2
        rows = slice(starty, starty + self.h_size)

        image_strip = image[rows, :, :]
        strip_height = image_strip.shape[0]
        pad_image[:strip_height, :, :] = image_strip
        pad_target[:strip_height, :, :] = target[rows, :, :]
        weight_strip = weight[rows, :]
        pad_weight[:weight_strip.shape[0], :] = weight_strip

        return pad_image, pad_target, pad_weight
                
class ImdbData(Dataset):
    """Dataset pairing each image with its label volume and weight map.

    Args:
        config: Mapping whose ``'general'`` section provides ``'HEIGHT'``,
            ``'WIDTH'`` and ``'layers'``.
        X: Indexable collection of 3-D image arrays.
        y: Indexable collection of 3-D label arrays.
        W: Indexable collection of weight arrays.
        transform: Callable invoked as ``transform(height, width, layers)`` that
            returns the per-sample transform (itself called as
            ``transform(image, label, weight)``), or ``None`` to disable
            transformation.
    """

    def __init__(self, config, X, y, W, transform):
        self.X = X
        self.y = y
        self.w = W
        self.height = config['general']['HEIGHT']
        self.width = config['general']['WIDTH']
        self.layers = config['general']['layers']
        # Only build the transform when one was supplied; __getitem__ already
        # skips transformation when self.transform is falsy.
        if transform is None:
            self.transform = None
        else:
            self.transform = transform(self.height, self.width, self.layers)

    def __getitem__(self, index):
        """Return ``(img, label, weight)`` tensors for the sample at ``index``."""
        # Move the first axis to the end so the transform sees it last.
        img = np.transpose(self.X[index], (1,2,0))
        label = np.transpose(self.y[index],(1,2,0))
        weight = self.w[index]
        if self.transform:
            img, label, weight = self.transform(img, label, weight)

        # .copy() materialises contiguous arrays (flipped views carry negative
        # strides, which torch.from_numpy rejects); permute restores the
        # original axis order.
        img = torch.from_numpy(img.copy()).float().permute(2,0,1)
        label = torch.from_numpy(label.copy()).long().permute(2,0,1)
        weight = torch.from_numpy(weight.copy()).float()

        return img, label, weight

    def __len__(self):
        """Number of samples, taken from ``X``."""
        return len(self.X)

In [3]:
# Read the experiment configuration from the YAML file in the working directory.
with open( "./train.yaml") as file:
    config = yaml.load(file, Loader=yaml.FullLoader)

# Image geometry and number of label channels.
general_cfg = config['general']
HEIGHT = general_cfg['HEIGHT']
WIDTH = general_cfg['WIDTH']
layers = general_cfg['layers']

# Paths and names taken from the 'filepaths' section.
path_cfg = config['filepaths']
exp_dir_name = path_cfg['exp_dir_name']
model_path = path_cfg['model_path']
data_dir = path_cfg['processed_data_path']

# Passed to the network constructor later in the notebook.
param = config['param']

(train_images, train_labels, train_wmaps,
 val_images, val_labels, val_wmaps) = get_imdb_data(data_dir)


def _image_stack(raw):
    """Reshape to (-1, HEIGHT, WIDTH), insert a singleton axis at position 1, and copy."""
    return np.copy(np.expand_dims(raw.reshape(-1, HEIGHT, WIDTH), axis=1))


def _label_stack(raw):
    """Reshape to (-1, layers, HEIGHT, WIDTH) and copy."""
    return np.copy(raw.reshape(-1, layers, HEIGHT, WIDTH))


def _wmap_stack(raw):
    """Reshape to (-1, HEIGHT, WIDTH) and copy."""
    return np.copy(raw.reshape(-1, HEIGHT, WIDTH))


train_images2 = _image_stack(train_images)
train_labels2 = _label_stack(train_labels)
train_wmaps2 = _wmap_stack(train_wmaps)

val_images2 = _image_stack(val_images)
val_labels2 = _label_stack(val_labels)
val_wmaps2 = _wmap_stack(val_wmaps)
    


In [4]:
class IdentityTransforms(object):
    """No-op transform with the same construction and call signature as RandomTransforms.

    ImdbData instantiates whatever it is given as ``transform(height, width, layers)``
    and then calls the instance on ``(image, target, weight)``; this class accepts
    both calls and returns the sample unchanged.
    """

    def __init__(self, height, width, layers):
        pass

    def __call__(self, image, target, weight):
        return image, target, weight


# Random augmentation is for training only. Validation samples are left
# untouched so validation metrics are measured on the real data and are
# repeatable from run to run.
train_dataset = ImdbData(config, train_images2, train_labels2, train_wmaps2, transform = RandomTransforms)
val_dataset = ImdbData(config, val_images2, val_labels2, val_wmaps2, transform = IdentityTransforms)


In [5]:

# Both loaders share batch size and worker count; only the training loader shuffles.
loader_options = {"batch_size": 4, "num_workers": 4}
train_loader = DataLoader(train_dataset, shuffle=True, **loader_options)
val_loader = DataLoader(val_dataset, shuffle=False, **loader_options)

# Network built from the 'param' section of the config, and the training driver.
relaynet_model = ReLayNet(param)
solver = Solver(optim_args={"lr": 1e-2})

solver.train(
    relaynet_model,
    train_loader,
    val_loader,
    model_path,
    log_nth=1,
    num_epochs=20,
    exp_dir_name=exp_dir_name,
)

START TRAIN.


  0%|          | 3/1470 [00:00<02:50,  8.61it/s]

tensor(4.1091, device='cuda:0', grad_fn=<AddBackward0>)
tensor(4.5328, device='cuda:0', grad_fn=<AddBackward0>)
tensor(4.2815, device='cuda:0', grad_fn=<AddBackward0>)


  0%|          | 5/1470 [00:00<02:35,  9.41it/s]

tensor(3.9424, device='cuda:0', grad_fn=<AddBackward0>)
tensor(3.7465, device='cuda:0', grad_fn=<AddBackward0>)
tensor(3.5494, device='cuda:0', grad_fn=<AddBackward0>)


  1%|          | 9/1470 [00:00<02:14, 10.87it/s]

tensor(3.2880, device='cuda:0', grad_fn=<AddBackward0>)
tensor(3.0603, device='cuda:0', grad_fn=<AddBackward0>)
tensor(2.7805, device='cuda:0', grad_fn=<AddBackward0>)


  1%|          | 11/1470 [00:01<02:08, 11.37it/s]

tensor(2.2831, device='cuda:0', grad_fn=<AddBackward0>)
tensor(2.6026, device='cuda:0', grad_fn=<AddBackward0>)
tensor(2.2054, device='cuda:0', grad_fn=<AddBackward0>)


  1%|          | 15/1470 [00:01<01:59, 12.21it/s]

tensor(1.9537, device='cuda:0', grad_fn=<AddBackward0>)
tensor(2.3023, device='cuda:0', grad_fn=<AddBackward0>)
tensor(1.9117, device='cuda:0', grad_fn=<AddBackward0>)


  1%|          | 17/1470 [00:01<01:55, 12.54it/s]

tensor(1.7887, device='cuda:0', grad_fn=<AddBackward0>)
tensor(1.9411, device='cuda:0', grad_fn=<AddBackward0>)
tensor(1.6191, device='cuda:0', grad_fn=<AddBackward0>)


  1%|▏         | 21/1470 [00:01<01:51, 13.04it/s]

tensor(1.5142, device='cuda:0', grad_fn=<AddBackward0>)
tensor(1.5753, device='cuda:0', grad_fn=<AddBackward0>)
tensor(1.5686, device='cuda:0', grad_fn=<AddBackward0>)


  2%|▏         | 23/1470 [00:01<01:56, 12.38it/s]

tensor(1.4198, device='cuda:0', grad_fn=<AddBackward0>)
tensor(1.3835, device='cuda:0', grad_fn=<AddBackward0>)
tensor(1.5229, device='cuda:0', grad_fn=<AddBackward0>)


  2%|▏         | 27/1470 [00:02<01:54, 12.56it/s]

tensor(1.3830, device='cuda:0', grad_fn=<AddBackward0>)
tensor(1.3659, device='cuda:0', grad_fn=<AddBackward0>)
tensor(1.3752, device='cuda:0', grad_fn=<AddBackward0>)


  2%|▏         | 29/1470 [00:02<01:52, 12.83it/s]

tensor(1.3854, device='cuda:0', grad_fn=<AddBackward0>)
tensor(1.2572, device='cuda:0', grad_fn=<AddBackward0>)
tensor(1.3441, device='cuda:0', grad_fn=<AddBackward0>)


  2%|▏         | 33/1470 [00:02<01:49, 13.18it/s]

tensor(1.3061, device='cuda:0', grad_fn=<AddBackward0>)
tensor(1.2749, device='cuda:0', grad_fn=<AddBackward0>)
tensor(1.3322, device='cuda:0', grad_fn=<AddBackward0>)


  2%|▏         | 35/1470 [00:02<01:48, 13.28it/s]

tensor(1.2490, device='cuda:0', grad_fn=<AddBackward0>)
tensor(1.2057, device='cuda:0', grad_fn=<AddBackward0>)
tensor(1.3011, device='cuda:0', grad_fn=<AddBackward0>)


  3%|▎         | 39/1470 [00:03<01:54, 12.51it/s]

tensor(1.2355, device='cuda:0', grad_fn=<AddBackward0>)
tensor(1.2429, device='cuda:0', grad_fn=<AddBackward0>)
tensor(1.2414, device='cuda:0', grad_fn=<AddBackward0>)


  3%|▎         | 41/1470 [00:03<01:56, 12.22it/s]

tensor(1.1483, device='cuda:0', grad_fn=<AddBackward0>)
tensor(1.3445, device='cuda:0', grad_fn=<AddBackward0>)
tensor(1.1899, device='cuda:0', grad_fn=<AddBackward0>)


  3%|▎         | 45/1470 [00:03<01:58, 12.06it/s]

tensor(1.2159, device='cuda:0', grad_fn=<AddBackward0>)
tensor(1.1672, device='cuda:0', grad_fn=<AddBackward0>)
tensor(1.2384, device='cuda:0', grad_fn=<AddBackward0>)


  3%|▎         | 47/1470 [00:03<01:54, 12.38it/s]

tensor(1.3088, device='cuda:0', grad_fn=<AddBackward0>)
tensor(1.1722, device='cuda:0', grad_fn=<AddBackward0>)
tensor(1.2878, device='cuda:0', grad_fn=<AddBackward0>)


  3%|▎         | 51/1470 [00:04<01:56, 12.18it/s]

tensor(1.1521, device='cuda:0', grad_fn=<AddBackward0>)
tensor(1.2720, device='cuda:0', grad_fn=<AddBackward0>)
tensor(1.2424, device='cuda:0', grad_fn=<AddBackward0>)


  4%|▎         | 53/1470 [00:04<01:54, 12.38it/s]

tensor(1.1885, device='cuda:0', grad_fn=<AddBackward0>)
tensor(1.1742, device='cuda:0', grad_fn=<AddBackward0>)
tensor(1.2307, device='cuda:0', grad_fn=<AddBackward0>)


  4%|▍         | 57/1470 [00:04<01:49, 12.86it/s]

tensor(1.2046, device='cuda:0', grad_fn=<AddBackward0>)
tensor(1.2080, device='cuda:0', grad_fn=<AddBackward0>)
tensor(1.1473, device='cuda:0', grad_fn=<AddBackward0>)


  4%|▍         | 59/1470 [00:04<01:49, 12.86it/s]

tensor(1.1447, device='cuda:0', grad_fn=<AddBackward0>)
tensor(1.2451, device='cuda:0', grad_fn=<AddBackward0>)
tensor(1.1812, device='cuda:0', grad_fn=<AddBackward0>)


  4%|▍         | 63/1470 [00:05<01:59, 11.80it/s]

tensor(1.1860, device='cuda:0', grad_fn=<AddBackward0>)
tensor(1.4174, device='cuda:0', grad_fn=<AddBackward0>)
tensor(1.1860, device='cuda:0', grad_fn=<AddBackward0>)


  4%|▍         | 65/1470 [00:05<01:56, 12.09it/s]

tensor(1.1983, device='cuda:0', grad_fn=<AddBackward0>)
tensor(1.1934, device='cuda:0', grad_fn=<AddBackward0>)
tensor(1.1914, device='cuda:0', grad_fn=<AddBackward0>)


  5%|▍         | 69/1470 [00:05<01:55, 12.09it/s]

tensor(1.1286, device='cuda:0', grad_fn=<AddBackward0>)
tensor(1.2208, device='cuda:0', grad_fn=<AddBackward0>)
tensor(1.1724, device='cuda:0', grad_fn=<AddBackward0>)


  5%|▍         | 71/1470 [00:05<01:54, 12.19it/s]

tensor(1.1364, device='cuda:0', grad_fn=<AddBackward0>)
tensor(1.2474, device='cuda:0', grad_fn=<AddBackward0>)
tensor(1.1365, device='cuda:0', grad_fn=<AddBackward0>)


  5%|▌         | 75/1470 [00:06<01:58, 11.81it/s]

tensor(1.2432, device='cuda:0', grad_fn=<AddBackward0>)
tensor(1.1601, device='cuda:0', grad_fn=<AddBackward0>)
tensor(1.1360, device='cuda:0', grad_fn=<AddBackward0>)


  5%|▌         | 77/1470 [00:06<01:54, 12.13it/s]

tensor(1.1401, device='cuda:0', grad_fn=<AddBackward0>)
tensor(1.1272, device='cuda:0', grad_fn=<AddBackward0>)
tensor(1.1751, device='cuda:0', grad_fn=<AddBackward0>)


  6%|▌         | 81/1470 [00:06<01:52, 12.32it/s]

tensor(1.1082, device='cuda:0', grad_fn=<AddBackward0>)
tensor(1.1859, device='cuda:0', grad_fn=<AddBackward0>)
tensor(1.2003, device='cuda:0', grad_fn=<AddBackward0>)


  6%|▌         | 83/1470 [00:06<01:54, 12.09it/s]

tensor(1.1536, device='cuda:0', grad_fn=<AddBackward0>)
tensor(1.1555, device='cuda:0', grad_fn=<AddBackward0>)
tensor(1.1808, device='cuda:0', grad_fn=<AddBackward0>)


  6%|▌         | 87/1470 [00:07<01:54, 12.03it/s]

tensor(1.1401, device='cuda:0', grad_fn=<AddBackward0>)
tensor(1.4463, device='cuda:0', grad_fn=<AddBackward0>)
tensor(1.1816, device='cuda:0', grad_fn=<AddBackward0>)


  6%|▌         | 89/1470 [00:07<01:54, 12.04it/s]

tensor(1.1996, device='cuda:0', grad_fn=<AddBackward0>)
tensor(1.1548, device='cuda:0', grad_fn=<AddBackward0>)





KeyboardInterrupt: 