<a href="https://colab.research.google.com/github/noaschaffer/flexible_network/blob/main/run.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
from net import Encoder
from net import Transition
from data import RectsData
from data import IMAGE_SIZE, FRAME_NUM

import torch
import itertools
import matplotlib.pyplot as plt
import numpy as np


def loss_function(e_x_t, t_e_x_t, e_x_t1):
    """Contrastive prediction loss averaged over the predictions in a batch.

    For every prediction ``t_e_x_t[i]`` the loss is the negative log of a
    softmax over negative squared distances: the "positive" is the distance
    to ``e_x_t1[i]`` and the normaliser sums over every encoding in
    ``e_x_t`` plus the last entry of ``e_x_t1``.

    The value is evaluated as ``squared_distance + logsumexp(-distances)``,
    which equals ``-log(exp(-d_pos) / sum(exp(-d_all)))`` algebraically but
    stays finite when the exponentials underflow (the explicit ratio turns
    into ``inf`` or ``0 / 0 = NaN`` once a prediction drifts far away from
    the encodings).

    Args:
        e_x_t: encodings of the current frames, indexed along dim 0.
        t_e_x_t: predicted next-frame encodings, indexed along dim 0.
        e_x_t1: encodings of the next frames; entry ``i`` is the target of
            ``t_e_x_t[i]``.

    Returns:
        The summed per-prediction losses divided by ``len(t_e_x_t)``; its
        shape is that of a single per-prediction term.
    """
    # Candidate set for the normaliser: all current encodings plus the final
    # next-frame encoding.
    all_e = torch.cat((e_x_t, e_x_t1[-1].unsqueeze(0)))
    loss = 0
    for i, prediction in enumerate(t_e_x_t):
        positive_sq_dist = (prediction - e_x_t1[i]) ** 2
        # log(sum(exp(-(prediction - all_e) ** 2))) over every element,
        # computed without ever materialising the tiny exponentials.
        log_partition = torch.logsumexp(
            -((prediction - all_e) ** 2).flatten(), dim=0)
        loss = loss + positive_sq_dist + log_partition
    return loss / len(t_e_x_t)


def train(encoder, transition, dataset, optimizer):
    """Train the encoder/transition pair on sliding windows of every clip.

    Each clip in ``dataset.data`` is cut into ``NUM_OF_BATCHES`` windows of
    ``HP_DICT['batch_size']`` frames whose starts are ``HP_DICT['step_size']``
    apart. Every window is optimised for ``HP_DICT['training_steps']`` steps
    and then the whole clip is plotted through ``eval``.

    Args:
        encoder: network applied to frames.
        transition: network applied to the encoder output.
        dataset: object exposing the clips as ``dataset.data``.
        optimizer: optimizer stepping the parameters of both networks.

    Returns:
        List with the loss value of every optimisation step, in order.
    """
    encoder.train()
    transition.train()

    window_len = HP_DICT['batch_size']
    stride = HP_DICT['step_size']
    errors = []

    for clip_num, clip in enumerate(dataset.data):
        print('clip number {}'.format(clip_num))

        for batch_index in range(NUM_OF_BATCHES):
            # Sliding window over the clip's frames.
            start = batch_index * stride
            window = clip[start: start + window_len]

            # Consecutive (frame, next frame) pairs taken from the window.
            x_t = window[:-1].view(-1, 1, IMAGE_SIZE, IMAGE_SIZE)
            x_t1 = window[1:].view(-1, 1, IMAGE_SIZE, IMAGE_SIZE)

            for _ in range(HP_DICT['training_steps']):
                optimizer.zero_grad()
                e_x_t = encoder(x_t)
                loss = loss_function(e_x_t, transition(e_x_t), encoder(x_t1))
                step_error = loss.item()
                print(step_error)
                errors.append(step_error)
                loss.backward()
                optimizer.step()

            # Visualise the whole clip after finishing this window.
            eval(encoder, transition, clip, batch_index)

    return errors


def eval(encoder, transition, clip, batch_index):
    """Plot encoder outputs and transition predictions for a whole clip.

    The encoder output of every frame is drawn in blue, the transition
    output is drawn in red shifted one position to the right, and the
    window of frames currently being trained on is shaded in orange.

    Note: this function shadows the builtin ``eval`` inside this module; the
    name is kept because ``train`` calls it. It does not change the
    train/eval mode of the networks.

    Args:
        encoder: network applied to the frames.
        transition: network applied to the encoder output.
        clip: tensor with the frames of one clip; it is reshaped to
            ``(-1, 1, IMAGE_SIZE, IMAGE_SIZE)``.
        batch_index: index of the training window to highlight.
    """
    plt.clf()
    x_t = clip.view(-1, 1, IMAGE_SIZE, IMAGE_SIZE)

    # Plot-only forward passes: no autograd graph is needed.
    with torch.no_grad():
        encoded = encoder(x_t)
        predicted = transition(encoded)

    # .cpu() is required before .numpy(): converting a CUDA tensor directly
    # raises, which made this function crash when HP_DICT['GPU'] was set.
    t_e_x_t = predicted.detach().cpu().flatten().numpy()
    e_x_t = encoded.detach().cpu().flatten().numpy()

    # NOTE(review): the x positions assume one scalar per frame after
    # flattening - confirm against the Encoder/Transition output sizes.
    plt.plot(e_x_t, 'b.', label='encoder output')
    # The prediction made from position i is drawn at position i + 1; the
    # last prediction has no matching position and is dropped.
    plt.plot(np.arange(1, len(e_x_t)), t_e_x_t[:-1], 'r.', label='prediction')
    plt.xlabel('frame')
    plt.ylabel('representation')

    # Shade the current training window across the y-range read here.
    y_lim = plt.gca().get_ylim()
    window_start = batch_index * HP_DICT['step_size']
    plt.fill_between(
        np.arange(window_start, window_start + HP_DICT['batch_size']),
        y_lim[0], y_lim[1], color='orange', alpha=0.25, label='training frame')
    plt.legend()
    plt.show()
    plt.pause(0.1)

def experiment():
    """Run one full training experiment and return its loss history.

    Creates fresh double-precision ``Encoder`` and ``Transition`` networks,
    one RMSprop optimizer over the parameters of both, and a ``RectsData``
    dataset built from ``HP_DICT``. When ``HP_DICT['GPU']`` is truthy the
    networks and the dataset are moved to ``'cuda'``. Interactive plotting
    is switched on before training.

    Returns:
        The list of per-step losses returned by ``train``.
    """
    encoder_net = Encoder().double()
    transition_net = Transition().double()

    # A single optimizer updates both networks.
    joint_params = itertools.chain(encoder_net.parameters(),
                                   transition_net.parameters())
    optimizer_predict = torch.optim.RMSprop(joint_params,
                                            lr=HP_DICT['learning_rate'])

    dataset = RectsData(HP_DICT)

    if HP_DICT['GPU']:
        encoder_net = encoder_net.to('cuda')
        transition_net = transition_net.to('cuda')
        dataset = dataset.to('cuda')

    plt.ion()
    history = train(encoder_net, transition_net, dataset, optimizer_predict)
    return history




clip number 0
1.9417942066782807
3.436908554670263
1.986989131897622
2.0585444395012185
1.9637777782849437
1.9513074864537303
1.9467240628088336
1.9373381918116042
1.9285353398253555
1.9235459050491635
1.9104012404977757
1.895578714730097
1.8957807925411156
1.8722889654012926
1.884533686887486
1.9009221958644311
1.8798577541137351
1.862989822294497
1.938060040414653
1.9000251693571428
1.8722820355124832
1.910635174083937
1.8816511743432578
1.8603227036334455
1.903133634578382
1.872893713798024
1.8756904312752634
1.8953401548061521
1.8697177776434646
1.8532265962554213
1.9818153898629722
1.8931363630958142
1.8672024480132527
clip number 1
1.9426130799348351
2.103353708303939
1.904190732214995
1.9400832839699227
1.906084921224312
1.8777863326161273
1.9220060044759488
1.8946899767476506
1.8692993243111402
1.9109032827970711
1.8862900635565278
1.8585967708461772
1.9168261405603104
1.8852076449715713
1.8602145597036905
1.924783389901697
1.8783471616546947
1.8538210137004787
1.93741262746278

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

In [None]:
if __name__ == '__main__':
    # Hyper-parameters read as globals by train(), eval() and experiment().
    HP_DICT = {
        'batch_size': 7,
        'step_size': 1,
        'training_steps': 3,
        'learning_rate': 1e-3,
        'GPU': False,
        'samples_num': 100,
        'switch_points': [40, 80],
    }
    # Number of sliding windows of batch_size frames, step_size apart.
    NUM_OF_BATCHES = int((FRAME_NUM - HP_DICT['batch_size']) / HP_DICT['step_size']) + 1

    # Average the per-step loss curves of three independent runs and plot
    # the log of the mean.
    errors = [experiment() for _ in range(3)]
    errors = np.array(errors).mean(axis=0)
    plt.plot(np.linspace(0, 100, len(errors)), np.log(errors))
    plt.show()
