In [66]:
import logging
import sys

import numpy as np
import torch
import torch.nn as nn
import torch.nn.modules.normalization as norm
import torch.utils.data.distributed
import torch.utils.model_zoo as model_zoo
import torchvision.datasets as datasets
import torchvision.models as models
import torchvision.transforms as transforms
from torch.autograd import Variable

In [60]:
# Hidden/cell-state width of each of Re3Net's two LSTM cells.
LSTM_SIZE: int = 512

In [61]:
class LRN(nn.Module):
    """Local response normalisation built on top of an average-pooling layer.

    Computes ``x / (1 + alpha * avg(x**2 over a window)) ** beta``.

    With ``ACROSS_CHANNELS=True`` the window spans ``local_size`` neighbouring
    channels at each spatial position. Otherwise it is a
    ``local_size x local_size`` spatial window inside each channel. The
    average uses the pooling layer's default zero padding, so border windows
    are still divided by the full window size.

    NOTE(review): this class is not referenced by the visible cells
    (``alexnet_conv_layers`` uses ``norm.LocalResponseNorm``). An even
    ``local_size`` makes the pooled tensor one element shorter than ``x``
    along the pooled axis, so the final division does not line up
    element-wise. Use odd sizes.
    """

    def __init__(self, local_size=1, alpha=1.0, beta=0.75, ACROSS_CHANNELS=True):
        super(LRN, self).__init__()
        half_window = int((local_size - 1) / 2)
        self.ACROSS_CHANNELS = ACROSS_CHANNELS
        if not ACROSS_CHANNELS:
            self.average = nn.AvgPool2d(kernel_size=local_size,
                                        stride=1,
                                        padding=half_window)
        else:
            # Pool over a (channels, 1, 1) window. forward() inserts a
            # singleton axis so that the channel axis becomes the pooled "depth".
            self.average = nn.AvgPool3d(kernel_size=(local_size, 1, 1),
                                        stride=1,
                                        padding=(half_window, 0, 0))
        self.alpha = alpha
        self.beta = beta

    def forward(self, x):
        squared = x.pow(2)
        if self.ACROSS_CHANNELS:
            pooled = self.average(squared.unsqueeze(1)).squeeze(1)
        else:
            pooled = self.average(squared)
        denominator = pooled.mul(self.alpha).add(1.0).pow(self.beta)
        return x.div(denominator)

In [62]:
class Flatten(nn.Module):
    """Collapse every dimension after the first (batch) axis into one.

    Implemented with ``Tensor.view``. The result shares storage with the
    input, and the input must be viewable in that shape (e.g. contiguous).
    """

    def forward(self, input):
        batch_size = input.size(0)
        return input.view(batch_size, -1)

In [72]:
#alexnet = models.alexnet(pretrained=True)
class alexnet_conv_layers(nn.Module):
    """AlexNet-style conv tower with skip connections, shared by two input images.

    Each image runs through ``conv1``..``conv5``. 1x1-conv "skip" branches
    after conv1, conv2 and conv5 are flattened and concatenated with the
    flattened ``pool5`` output, giving one feature vector per image. The two
    images' vectors are concatenated and projected to 2048 units by ``conv6``.

    ``conv6`` is sized for 37104 features per image:
    16*27*27 (skip1) + 32*13*13 (skip2) + 64*13*13 (skip5) + 256*6*6 (pool5).
    That is what a 3x227x227 input produces; other input sizes will not
    match ``conv6``.
    """

    def __init__(self):
        super(alexnet_conv_layers, self).__init__()
        input_channels = 3

        # NOTE(review): size=2 is an unusual (even) LRN window. If these values
        # were ported from a TensorFlow LRN with depth_radius=2, the PyTorch
        # equivalent would be size=5 with alpha scaled by the window size.
        # Confirm against the reference model before changing it.
        self.conv1 = nn.Sequential(
            nn.Conv2d(input_channels, out_channels=96, kernel_size=11, stride=4),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2),
            norm.LocalResponseNorm(size=2, alpha=2e-5, beta=0.75, k=1.0)
        )
        self.skip1 = nn.Sequential(
            nn.Conv2d(96, out_channels=16, kernel_size=1, stride=1),
            nn.PReLU(),
            Flatten()
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(in_channels=96, out_channels=256, groups=2, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2),
            norm.LocalResponseNorm(size=2, alpha=2e-5, beta=0.75, k=1.0)
        )

        self.skip2 = nn.Sequential(
            nn.Conv2d(256, out_channels=32, kernel_size=1, stride=1),
            nn.PReLU(),
            Flatten()
        )

        self.conv3 = nn.Sequential(
            nn.Conv2d(in_channels=256, out_channels=384, kernel_size=3, padding=1),
            nn.ReLU()
        )

        self.conv4 = nn.Sequential(
            nn.Conv2d(in_channels=384, out_channels=384, kernel_size=3, padding=1, groups=2),
            nn.ReLU()
        )

        self.conv5 = nn.Sequential(
            nn.Conv2d(in_channels=384, out_channels=256, kernel_size=3, padding=1, groups=2),
            nn.ReLU()
        )

        self.pool5 = nn.Sequential(
            nn.MaxPool2d(kernel_size=3, stride=2)
        )

        self.conv5_flat = nn.Sequential(
            Flatten()
        )

        self.skip5 = nn.Sequential(
            nn.Conv2d(256, out_channels=64, kernel_size=1, stride=1),
            nn.PReLU(),
            Flatten()
        )

        self.conv6 = nn.Sequential(
            # Two images' concatenated per-image features (see class docstring).
            nn.Linear(37104 * 2, 2048),
            nn.ReLU()
        )

    def _image_features(self, img):
        """Return one image's flattened skip1/skip2/skip5/pool5 features, concatenated."""
        out1 = self.conv1(img)
        out2 = self.conv2(out1)
        out5 = self.conv5(self.conv4(self.conv3(out2)))
        pooled = self.conv5_flat(self.pool5(out5))
        return torch.cat((self.skip1(out1), self.skip2(out2), self.skip5(out5), pooled), dim=1)

    def forward(self, x, y=None):
        """Encode an image pair and project it to 2048 features.

        Args:
            x: first image batch (presumably the previous-frame crop).
            y: second image batch (presumably the current-frame crop). If
                omitted, ``x`` is used for both halves, which matches the
                previous single-input behaviour.

        Returns:
            Output of ``conv6`` with 2048 features per sample.
        """
        x_feats = self._image_features(x)
        # Reuse x's features instead of recomputing the identical tower pass.
        y_feats = x_feats if y is None else self._image_features(y)
        return self.conv6(torch.cat((x_feats, y_feats), dim=1))

In [73]:
class Re3Net(nn.Module):
    """Re3-style tracker: conv tower -> two stacked LSTM cells -> 4-value linear head.

    The recurrent state of both LSTM layers is kept on the module between
    forward calls (layer 1: ``h0``/``c0``, layer 2: ``h1``/``c1``), so
    consecutive calls continue one unrolled sequence. Call ``init_hidden()``
    to start a new sequence. The state is also re-created as zeros whenever
    the batch size, device or dtype of the conv features changes.

    NOTE(review): the stored state keeps its autograd history. If backward
    runs once per call and the weights are updated in between, reset (or
    detach) the state before the next call. Otherwise backward will try to
    reach through the previous call's graph.
    """

    def __init__(self):
        super(Re3Net, self).__init__()
        self.conv_layers = alexnet_conv_layers()

        # alexnet_conv_layers ends in Linear(..., 2048), so the conv features are 2048 wide.
        self.lstm1 = nn.LSTMCell(2048, LSTM_SIZE)
        # Layer 2 sees the conv features concatenated with layer 1's hidden state.
        self.lstm2 = nn.LSTMCell(2048 + LSTM_SIZE, LSTM_SIZE)

        # 4 outputs (presumably the bounding box the training loops regress against).
        self.fc_final = nn.Linear(LSTM_SIZE, 4)

        self.init_hidden()

    def init_hidden(self, batch_size=1, device=None, dtype=None):
        """Reset both LSTM layers' hidden and cell states to zeros.

        Args:
            batch_size: number of sequences in the batch (default 1).
            device, dtype: placement of the new tensors; ``None`` uses torch's defaults.
        """
        def zeros(cell):
            return torch.zeros(batch_size, cell.hidden_size, device=device, dtype=dtype)

        self.h0, self.c0 = zeros(self.lstm1), zeros(self.lstm1)
        self.h1, self.c1 = zeros(self.lstm2), zeros(self.lstm2)

    def forward(self, x, prev_LSTM_state=False):
        """Run one time step and return the 4-value prediction.

        Args:
            x: input passed to ``self.conv_layers``.
            prev_LSTM_state: currently unused; kept for call compatibility.
                NOTE(review): the training loops in this notebook call
                ``net(x1, x2)``, so the current-frame image lands here and is
                ignored. The conv tower only ever sees ``x``.

        Returns:
            Output of ``fc_final`` (4 values per sample).
        """
        features = self.conv_layers(x)

        # Lazily (re)size the recurrent state to match this batch.
        if (self.h0.size(0) != features.size(0)
                or self.h0.device != features.device
                or self.h0.dtype != features.dtype):
            self.init_hidden(features.size(0), device=features.device, dtype=features.dtype)

        # LSTMCell returns (h, c); each layer has its own state.
        l1_h, l1_c = self.lstm1(features, (self.h0, self.c0))
        l2_h, l2_c = self.lstm2(torch.cat((features, l1_h), dim=1), (self.h1, self.c1))

        self.h0, self.c0 = l1_h, l1_c
        self.h1, self.c1 = l2_h, l2_c
        return self.fc_final(l2_h)

In [77]:
# Build the tracker. The conv tower's conv6 Linear alone holds ~152M weights.
re3: Re3Net = Re3Net()

In [79]:
class RunningAverage():
    """Maintain the running mean of a scalar quantity (e.g. a loss).

    Example::

        loss_avg = RunningAverage()
        loss_avg.update(2)
        loss_avg.update(4)
        loss_avg()  # -> 3.0

    Calling the object before any ``update`` returns ``nan`` instead of
    raising ``ZeroDivisionError``.
    """

    def __init__(self):
        self.steps = 0  # number of values seen
        self.total = 0  # sum of values seen

    def update(self, val):
        """Add one observation ``val`` to the running total."""
        self.total += val
        self.steps += 1

    def __call__(self):
        """Return the mean of all values passed to ``update`` (``nan`` if none)."""
        if self.steps == 0:
            return float('nan')
        return self.total / float(self.steps)

In [83]:
def train(model, optimizer, loss_fn, dataloader, metrics, params):
    """Train ``model`` for one pass over ``dataloader`` and log loss summaries.

    Args:
        model: module called as ``model(previmg, currimg)``.
        optimizer: optimizer over ``model``'s parameters.
        loss_fn: callable ``loss_fn(output, target)`` returning a scalar tensor.
        dataloader: iterable of dicts with keys ``'previmg'``, ``'currimg'``
            and ``'currbb'``.
        metrics: currently unused (only the loss is summarised); kept for
            call compatibility.
        params: object with ``save_summary_steps`` (summary interval, in
            batches) and ``batch_size`` (divides the logged per-iteration loss).

    Returns:
        dict mapping metric name to its mean over the summarised batches,
        or an empty dict if ``dataloader`` yielded nothing.
    """
    model.train()

    summ = []  # one {'loss': float} entry every save_summary_steps batches
    total_loss = 0.0
    num_batches = 0
    for i, data in enumerate(dataloader):
        # Batches are independent pairs: clear any recurrent state so backward
        # never reaches into an earlier batch's graph. This replaces the
        # previous retain_graph=True workaround.
        if hasattr(model, 'init_hidden'):
            model.init_hidden()

        optimizer.zero_grad()
        x1, x2, y = data['previmg'], data['currimg'], data['currbb']
        # NOTE(review): with Re3Net, forward(x, prev_LSTM_state=False) receives
        # x2 as prev_LSTM_state, so the current image is ignored.
        output = model(x1, x2)
        loss = loss_fn(output, y)
        loss.backward()
        # performs updates using calculated gradients
        optimizer.step()

        # `.item()` replaces `loss.data[0]`, which fails on 0-dim tensors in current PyTorch.
        loss_value = loss.item()
        if i % params.save_summary_steps == 0:
            summ.append({'loss': loss_value})
            logging.info('- Average Loss for iteration {} is {}'.format(i, loss_value / params.batch_size))

        total_loss += loss_value
        num_batches += 1

    if num_batches == 0:
        logging.warning('- No batches to train on')
        return {}

    logging.info('- Trained on {} batches, mean loss {:05.3f}'.format(num_batches, total_loss / num_batches))
    # compute mean of all metrics in summary
    metrics_mean = {metric: np.mean([x[metric] for x in summ]) for metric in summ[0]}
    metrics_string = " ; ".join("{}: {:05.3f}".format(k, v) for k, v in metrics_mean.items())
    logging.info("- Train metrics: " + metrics_string)
    return metrics_mean

In [84]:
def train_model(net, dataloader, optim, loss_function, num_epochs, val_dataloader=None):
    """Train ``net`` for ``num_epochs`` epochs, checkpointing and validating after each.

    Relies on globals defined outside this cell: ``use_gpu`` (move batches
    to CUDA), ``save_directory`` (checkpoint path prefix) and
    ``evaluate(net, dataloader, loss_function, epoch)``.

    Args:
        net: model called as ``net(previmg, currimg)``.
        dataloader: training loader yielding dicts with ``'previmg'``,
            ``'currimg'`` and ``'currbb'``. Its dataset must expose ``.len``.
        optim: optimizer over ``net``'s parameters.
        loss_function: callable ``loss_function(output, target)`` returning a
            scalar tensor.
        num_epochs: number of passes over ``dataloader``.
        val_dataloader: loader passed to ``evaluate``. Defaults to
            ``dataloader``, which matches the previous behaviour but validates
            on the training data.

    Returns:
        The trained ``net``.
    """
    if val_dataloader is None:
        # NOTE(review): this reports training-set loss as "validation"; pass a held-out loader.
        val_dataloader = dataloader

    dataset_size = dataloader.dataset.len
    for epoch in range(num_epochs):
        net.train()
        curr_loss = 0.0

        # currently training on just ALOV dataset
        for i, data in enumerate(dataloader):
            x1, x2, y = data['previmg'], data['currimg'], data['currbb']
            if use_gpu:
                x1, x2, y = x1.cuda(), x2.cuda(), y.cuda()

            # Batches are independent pairs: clear any recurrent state so
            # backward never reaches into an earlier batch's graph. This
            # replaces the previous retain_graph=True workaround.
            if hasattr(net, 'init_hidden'):
                net.init_hidden()

            optim.zero_grad()
            # NOTE(review): with Re3Net, forward(x, prev_LSTM_state=False)
            # receives x2 as prev_LSTM_state, so the current image is ignored.
            output = net(x1, x2)
            loss = loss_function(output, y)
            loss.backward()
            optim.step()

            # `.item()` replaces `loss.data[0]`, which fails on 0-dim tensors in current PyTorch.
            loss_value = loss.item()
            print('[training] epoch = %d, i = %d/%d, loss = %f' % (epoch, i, dataset_size, loss_value))
            sys.stdout.flush()
            curr_loss += loss_value

        # Sum of per-batch losses divided by the sample count. This is a
        # per-sample mean only if loss_function sums over the batch.
        epoch_loss = curr_loss / dataset_size if dataset_size else float('nan')
        print('Loss: {:.4f}'.format(epoch_loss))

        path = save_directory + '_batch_' + str(epoch) + '_loss_' + str(round(epoch_loss, 3)) + '.pth'
        torch.save(net.state_dict(), path)

        val_loss = evaluate(net, val_dataloader, loss_function, epoch)
        print('Validation Loss: {:.4f}'.format(val_loss))
    return net