This notebook is where the network is run and tested. All setup can be done here, while specialised code should be delegated to its own Python file. Ideally, the workflow run here will later be adapted into a 'main' Python script that can be executed from the command line.

In [None]:
# Import the (probably) necessary imports.
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import torch.optim as optim
from torch.utils.data import DataLoader

from torchvision import transforms, utils

from skimage import io, transform

import os

# For Early Stopping Callback
from pytorch_lightning.callbacks.early_stopping import EarlyStopping

# Ray tune for tuning the hyperparameters
from ray import tune
from ray.tune import CLIReporter
from ray.tune.schedulers import ASHAScheduler
from ray.tune.integration.pytorch_lightning import TuneReportCallback, \
        TuneReportCheckpointCallback

# Probable project code structure
from project_code.utils import preprocessing
from project_code.data.zebrafish_data_module import *
from project_code.networks.rnn import *

In [None]:
# Set up TensorBoard for debugging. This cell uses IPython magics (%...), so it
# only runs inside Jupyter/IPython, not as a plain Python script.
import tensorboard

# PyTorch Lightning writes its logs to ./lightning_logs, so TensorBoard is
# pointed at that directory.
# Open question: whether old logs should still be removed between runs.

%load_ext tensorboard
%tensorboard --logdir lightning_logs

# When running this notebook locally, TensorBoard is also reachable at
# 127.0.0.1:6006.

# Clean up old logs.
# NOTE(review): this runs after TensorBoard has already been started above, so
# the server is watching a directory that is being deleted; consider clearing
# the logs before launching TensorBoard.
if os.path.isdir('./lightning_logs/'):
  import shutil
  shutil.rmtree('lightning_logs/')

from torch.utils.tensorboard import SummaryWriter

# Point the writer explicitly at lightning_logs (SummaryWriter would otherwise
# default to a ./runs/... directory).
# NOTE(review): `writer` is not used anywhere else in the visible notebook.
writer = SummaryWriter('lightning_logs')

In [None]:
# Define a simple Dataset and DataModule for testing. The Dataset contains
# generated sequences of white noise with two different means. The model
# should learn to tell the two classes apart and update its belief as it
# sees more of a sequence; if it does, we know that the model works.

import torch
from torch.utils.data import Dataset

import pytorch_lightning as pl

class TestDataset(Dataset):
    """Synthetic two-class dataset of smoothed Gaussian noise sequences.

    Even-indexed samples are drawn around ``mean1`` and labelled 1, odd-indexed
    samples are drawn around ``mean2`` and labelled 0. A model that separates
    the two classes (and refines its belief as it sees more of a sequence)
    demonstrates that the training pipeline works end to end.

    Args:
        mean1: Mean of the noise for class 1 (even indices).
        mean2: Mean of the noise for class 0 (odd indices).
        num_samples: Number of sequences to generate.
        sequence_length: Number of time steps per sequence (at least 1).
        noise_std: Standard deviation of the Gaussian noise.
        smoothing_window: Width of the moving-average (low-pass) filter;
            1 disables smoothing.
        plot_examples: If True, plot the first two generated sequences.

    Each item is ``(sequence, target)``: ``sequence`` has shape
    ``(sequence_length, 1)`` and ``target`` has shape ``(1,)``, both float64.

    Raises:
        ValueError: If ``sequence_length`` or ``smoothing_window`` is below 1.
    """

    def __init__(self,
                 mean1=1,
                 mean2=5,
                 num_samples=1000,
                 sequence_length=1000,
                 noise_std=0.25,
                 smoothing_window=10,
                 plot_examples=True):
        if sequence_length < 1:
            raise ValueError('sequence_length must be at least 1.')
        if smoothing_window < 1:
            raise ValueError('smoothing_window must be at least 1.')

        self.mean1 = mean1
        self.mean2 = mean2
        self.num_samples = num_samples
        self.sequence_length = sequence_length
        self.noise_std = noise_std
        self.smoothing_window = smoothing_window

        # Moving-average kernel. Low-pass filtering the noise keeps the
        # confidence values from changing too quickly, so integration over
        # time is easier to see.
        kernel = np.ones(smoothing_window) / smoothing_window
        # Start of the centred window inside the 'full' convolution. For
        # sequences at least as long as the kernel this reproduces
        # np.convolve(..., mode='same'), which would otherwise return
        # max(sequence_length, smoothing_window) samples for short sequences.
        # Zero padding still pulls the first/last few samples towards 0.
        offset = (smoothing_window - 1) // 2

        self.sequences = []
        self.sequences_target = []
        for i in range(self.num_samples):
            # Alternate the classes so that half the samples use each mean.
            if i % 2 == 0:
                mean, target = mean1, np.ones((1,))
            else:
                mean, target = mean2, np.zeros((1,))

            noise = np.random.normal(mean, noise_std, self.sequence_length)
            smoothed = np.convolve(noise, kernel, mode='full')
            sequence = smoothed[offset:offset + self.sequence_length]
            sequence = sequence[:, np.newaxis]

            self.sequences.append(sequence)
            self.sequences_target.append(target)

            if plot_examples and i < 2:
                plt.figure()
                plt.plot(sequence)
                plt.show()

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        return (self.sequences[idx], self.sequences_target[idx])

class TestDataModule(pl.LightningDataModule):
    """LightningDataModule wrapping ``TestDataset`` with an 80/10/10 split.

    Args:
        mean1: Mean of class 1, forwarded to TestDataset.
        mean2: Mean of class 0, forwarded to TestDataset.
        batch_size: Batch size used by every DataLoader.
        num_samples: Total number of generated sequences.
        sequence_length: Length of each sequence.
        num_workers: Worker processes used by every DataLoader.
    """

    def __init__(self,
                 mean1=1,
                 mean2=5,
                 batch_size=32,
                 num_samples=1000,
                 sequence_length=1000,
                 num_workers=4):
        super().__init__()
        self.batch_size = batch_size
        self.mean1 = mean1
        self.mean2 = mean2
        self.num_samples = num_samples
        self.sequence_length = sequence_length
        self.num_workers = num_workers

    def setup(self, stage=None):
        # Make assignments here (train/val/test split). Lightning calls this
        # on every process; a new TestDataset is generated on every call.
        dataset = TestDataset(self.mean1,
                              self.mean2,
                              self.num_samples,
                              self.sequence_length)

        # Like in the RestingDataset we use an 80/10/10 split. Whatever is
        # left over after rounding goes to the test set, so the three lengths
        # always add up to num_samples.
        num_train = round(self.num_samples * 0.8)
        num_val = round(self.num_samples * 0.1)
        num_test = self.num_samples - num_train - num_val

        self.train, self.val, self.test = torch.utils.data.random_split(
            dataset, [num_train, num_val, num_test])

    def _make_dataloader(self, subset):
        """Build a DataLoader for one split with the module's settings."""
        return DataLoader(subset,
                          num_workers=self.num_workers,
                          batch_size=self.batch_size)

    def train_dataloader(self):
        return self._make_dataloader(self.train)

    def val_dataloader(self):
        return self._make_dataloader(self.val)

    def test_dataloader(self):
        return self._make_dataloader(self.test)

In [None]:
from matplotlib.collections import LineCollection

def cart2pol(x, y):
    """Convert Cartesian coordinates to polar coordinates.

    Works element-wise on scalars or numpy arrays.

    Returns:
        ``(rho, phi)``: the distance from the origin and the angle in radians
        as given by ``np.arctan2(y, x)``.
    """
    radius = np.sqrt(x**2 + y**2)
    return radius, np.arctan2(y, x)

def pol2cart(rho, phi):
    """Convert polar coordinates, with the angle in degrees, to Cartesian ones.

    0 degrees lies along the positive x axis and angles increase
    counter-clockwise (towards the positive y axis), following the standard
    mathematical convention.

    Returns:
        ``(x, y)``; works element-wise on scalars or numpy arrays.
    """
    phi_rad = np.deg2rad(phi)
    return rho * np.cos(phi_rad), rho * np.sin(phi_rad)

def update_orientation(orientation, angle):
    """Rotate a heading by a relative angle and wrap the result into [0, 360).

    Args:
        orientation: Current heading in degrees.
        angle: Relative turn in degrees; any real value, including negative
            values and multiple full turns.

    Returns:
        The new heading in degrees, wrapped with ``% 360`` into [0, 360)
        (up to floating-point rounding). Non-finite angles yield NaN.
    """
    # Python's % with a positive modulus always returns a result with the
    # sign of the modulus, so a single modulo normalises negative and
    # multi-turn angles exactly. A repeated +/-360 loop would never terminate
    # for inf (or for floats so large that x - 360 == x).
    return (orientation + angle % 360) % 360

# Modelling function
def generate_fish_trace(beta_parameters, exp_length, frequency):
    a, b = beta_parameters
    
    # Initialise an empty trace.
    simulated_trace = np.empty((exp_length * frequency, 2))
    angle = np.empty((exp_length * frequency, 1))
    x = simulated_trace[:, 0]
    y = simulated_trace[:, 1]
    
    # Set model parameters
    bout_freq = 1 # 1 bout per second.
    avg_bout_vel = 0.5 # 0.5mm per second.
    vel_duration = 1 # the fish will keep a velocity > 0 for 1 second after the bout.
    avg_peak_velocity = 0.5 # at the start of the bout, the average velocity is 0.5mm.
    
    # Initialise the fish at a random location.
    # Random location is a Gaussian around 0.
    x[0], y[0] = np.random.normal(0, 3), np.random.normal(0, 3)
    
    # Initialise the first bout.
    bout_countdown = int(np.random.normal(
                        loc=bout_freq * frequency,
                        scale=(frequency / 2) / bout_freq))
    if bout_countdown < 0:
        bout_countdown = 0
        
    # Initialise current velocity
    current_velocity = 0
    
    # Initialise current fish orientation (in degrees).
    # Bout angles are relative, so we need to keep
    # track of the orientation.
    current_orientation = np.random.random_sample() * 360
    #current_orientation = 0
    angle[0] = current_orientation
    
    # Simulate every measurement.
    for i in range(1, exp_length * frequency):
        # If the bout_countdown is 0, then we bout.
        if bout_countdown == 0:
            current_velocity = np.random.normal(
                                loc=avg_peak_velocity, 
                                scale=avg_peak_velocity*0.3)
            current_orientation = update_orientation(
                                    current_orientation,
                                    (np.random.beta(a, b) * 360) - 180)
            #current_velocity = np.random.beta(a, b)
            #current_orientation = update_orientation(
            #                        current_orientation,
            #                        np.random.normal(
            #                            scale=90)
            #                      )
            #if b == 20:
            #    current_orientation = update_orientation(current_orientation, 90)
            #    current_orientation = current_orientation + 90
            #else:
            #    current_orientation = update_orientation(current_orientation, -90)
            #    current_orientation = current_orientation - 90
            
            # We have to set a new timer.
            bout_countdown = int(np.random.normal(
                                loc=bout_freq * frequency,
                                scale=(frequency / 2) / bout_freq))
            if bout_countdown < 0:
                bout_countdown = 0
        else:
            bout_countdown = bout_countdown - 1
            
            # Update velocity, since it goes down over time.
            # We want to go back to 0 in 1 second on average.
            # We'll keep it simple and do it linearly.
            current_velocity = current_velocity \
                    - (avg_peak_velocity / frequency)
            if current_velocity < 0:
                current_velocity = 0
        
        # Update position.
        moved_x, moved_y = pol2cart(current_velocity, 
                                    current_orientation)
        
        x[i] = x[i-1] + moved_x
        y[i] = y[i-1] + moved_y
        
        angle[i] = current_orientation
        
        # Add white noise to the measurement.
        #x[i] = x[i] + np.random.normal(0, 0.1)
        #y[i] = y[i] + np.random.normal(0, 0.1)
        
    #plt.figure()
    #plt.plot(angle)
    #plt.show()
        
    return simulated_trace
        
def plot_simulated_trace(trace, title, duration=60):
    """Plot a simulated (x, y) trace as a line whose colour encodes time.

    Args:
        trace: Array of shape (n, 2) with x in column 0 and y in column 1.
        title: Figure title.
        duration: Time in seconds spanned by the trace; only used for the
            colour scale.
    """
    points = trace[:, :2]
    # One segment per pair of consecutive samples: (n - 1, 2 points, 2 coords).
    segments = np.stack([points[:-1], points[1:]], axis=1)

    fig, ax = plt.subplots()
    lc = LineCollection(segments, cmap='viridis')
    lc.set_array(np.linspace(0, duration, len(trace) - 1))
    line = ax.add_collection(lc)
    ax.set_xlim(points[:, 0].min(), points[:, 0].max())
    ax.set_ylim(points[:, 1].min(), points[:, 1].max())
    plt.colorbar(line)
    plt.title(title)
    plt.show()


# Run the simulation for both behaviour types and plot the traces.
simulated_trace = generate_fish_trace((10, 10), 60, 100)
plot_simulated_trace(simulated_trace, 'Simulated fish 1')

simulated_trace = generate_fish_trace((30, 15), 60, 100)
plot_simulated_trace(simulated_trace, 'Simulated fish 2')

In [None]:
# Define a Dataset and DataModule of simulated zebrafish traces with two
# different bout-angle behaviours. The model should learn to tell the two
# behaviours apart and update its belief as it sees more of a trace; if it
# does, we know that the model works.

import torch
from torch.utils.data import Dataset

import pytorch_lightning as pl

class ModelledZebrafishDataset(Dataset):
    
    def __init__(self,
                 num_samples=100,
                 exp_length=60):
        
        # Number of fish.
        self.num_samples = num_samples
        
        # Length of the experiment in seconds.
        self.exp_length = exp_length
        
        # We "measure" at a frequency of 100Hz.
        frequency = 25
        
        # Generate the traces and put them in an array.
        self.traces = []
        self.traces_target = []
        for i in range(self.num_samples):
            # Make sure that half is of one type of behaviour and half of the other.
            if i % 2 == 0:
                trace = generate_fish_trace((10, 10), self.exp_length, frequency)
                #trace = np.ones((self.exp_length * frequency, 1))
                target = np.ones((1,))
            else:
                trace = generate_fish_trace((30, 15), self.exp_length, frequency)
                #trace = np.zeros((self.exp_length * frequency, 1)) + 0.01
                target = np.zeros((1,))
                
            # Normalise the trace.
            #norm = np.linalg.norm(trace)
            #trace = trace / norm
            
            extracted_angles = np.empty((len(trace), 1))
            
            last_angle = 0
            relative_angle = 0
            
            simulated_trace = trace

            for i in range(1, len(simulated_trace)):
                current_x = simulated_trace[i, 0]
                current_y = simulated_trace[i, 1]

                previous_x = simulated_trace[i-1, 0]
                previous_y = simulated_trace[i-1, 1]

                # Calculate the relative change in position.
                rel_x = current_x - previous_x
                rel_y = current_y - previous_y

                # Calculate the polar coordinates of this.
                rho, phi = cart2pol(rel_x, rel_y)
                
                # Convert the angle to degrees for readability.
                phi = np.rad2deg(phi)

                # Round Phi, since we might have a bit of noise.
                phi = round(phi)
                # The angle calculation contains some noise, so it can occur
                # that the angle goes from -180 to 180 or the other way around.
                # In order not to extract weird angles, check for this.
                if phi != last_angle and not ((phi == 180 and last_angle == -180) or phi == -180 and last_angle == -180) and rho > 0:
                    # Take the (180, -180) range into account.
                    if phi - last_angle > 180:
                        relative_angle = -1 * ((phi - last_angle) % 180)
                    elif phi - last_angle < -180:
                        relative_angle = -1 * ((phi - last_angle) % -180)
                    else:
                        relative_angle = phi - last_angle
                    #print('Last angle: {}'.format(last_angle))
                    last_angle = phi
                    #print('Phi: {}'.format(phi))
                    #print('Last angle: {}'.format(np.rad2deg(last_angle)))
                    #print('Relative angle: {}'.format(relative_angle))
        
                extracted_angles[i] = relative_angle

            self.traces.append(simulated_trace)
            #self.traces.append(extracted_angles)
            self.traces_target.append(target)
            
            #plt.figure()
            #plt.plot(extracted_angles)
            #plt.show()
                
        
    def __len__(self):
        return self.num_samples
    
    def __getitem__(self, idx):
        return (self.traces[idx], self.traces_target[idx])

class ModelledZebrafishDataModule(pl.LightningDataModule):
    """LightningDataModule wrapping ``ModelledZebrafishDataset`` (80/10/10 split).

    Args:
        batch_size: Batch size used by every DataLoader.
        num_samples: Total number of simulated fish.
        exp_length: Length of each simulated experiment in seconds.
        num_workers: Worker processes used by every DataLoader.
    """

    def __init__(self,
                 batch_size=32,
                 num_samples=100,
                 exp_length=60,
                 num_workers=4):
        super().__init__()
        self.batch_size = batch_size
        self.num_samples = num_samples
        self.exp_length = exp_length
        self.num_workers = num_workers

    def setup(self, stage=None):
        # Make assignments here (train/val/test split). Lightning calls this
        # on every process; new traces are simulated on every call.
        dataset = ModelledZebrafishDataset(self.num_samples,
                                           self.exp_length)

        # Like in the RestingDataset we use an 80/10/10 split. Whatever is
        # left over after rounding goes to the test set, so the three lengths
        # always add up to num_samples.
        num_train = round(self.num_samples * 0.8)
        num_val = round(self.num_samples * 0.1)
        num_test = self.num_samples - num_train - num_val

        self.train, self.val, self.test = torch.utils.data.random_split(
            dataset, [num_train, num_val, num_test])

    def _make_dataloader(self, subset):
        """Build a DataLoader for one split with the module's settings."""
        return DataLoader(subset,
                          num_workers=self.num_workers,
                          batch_size=self.batch_size)

    def train_dataloader(self):
        return self._make_dataloader(self.train)

    def val_dataloader(self):
        return self._make_dataloader(self.val)

    def test_dataloader(self):
        return self._make_dataloader(self.test)

In [None]:
# Hyperparameter tuning with Ray Tune.

# Forward Lightning's logged 'val_loss' to Tune under the metric name 'loss'
# at the end of every validation run; train_with_tune attaches this callback
# to its Trainer.
_tune_metrics = {'loss': 'val_loss'}
tune_report_callback = TuneReportCallback(_tune_metrics, on='validation_end')

def train_with_tune(config, num_epochs=40, num_gpus=0):
    """Train a single MutationNet trial for Ray Tune.

    Args:
        config: Hyperparameters sampled by Tune, passed straight to MutationNet.
        num_epochs: Maximum number of training epochs for this trial.
        num_gpus: GPUs reserved for the trial by Tune; fractional values are
            rounded up when choosing the Lightning devices.

    The validation loss is reported to Tune through the module-level
    ``tune_report_callback``.
    """
    import math

    # The TestDataModule produces float64 numpy data, so the model must use
    # double precision as well.
    model = MutationNet(config).double()

    #data_module = ZebrafishDataModule(batch_size=1, sampling_frequency=100)
    data_module = TestDataModule(mean1=1, mean2=1.5, num_samples=1000, sequence_length=1000)

    # Honour the epoch budget (ASHA's max_t) and the GPUs Tune reserved for
    # this trial; without these the Trainer uses its own default epoch limit
    # and ignores the reservation. enable_progress_bar replaces the
    # deprecated progress_bar_refresh_rate=0.
    use_gpu = num_gpus > 0
    trainer = pl.Trainer(
        max_epochs=num_epochs,
        accelerator='gpu' if use_gpu else 'cpu',
        devices=math.ceil(num_gpus) if use_gpu else 1,
        callbacks=[tune_report_callback],
        enable_progress_bar=False,
    )
    trainer.fit(model, data_module)
    
def tune_model_asha(num_samples=10, num_epochs=10, gpus_per_trial=0):
    """Search MutationNet hyperparameters with Ray Tune and the ASHA scheduler.

    Args:
        num_samples: Number of hyperparameter configurations (trials) to run.
        num_epochs: Maximum training iterations per trial (ASHA's ``max_t``),
            also forwarded to ``train_with_tune``.
        gpus_per_trial: GPUs reserved for every trial.

    Returns:
        The best configuration found, judged by the minimum reported 'loss'.
    """
    search_space = {
        'model_type': tune.choice(['LSTM', 'GRU']),
        'input_size': 1,
        'hidden_size': tune.lograndint(1, 100),
        'num_layers': tune.lograndint(1, 10),
        'learning_rate': tune.loguniform(1e-4, 1e-1),
    }

    # Early-stop poorly performing trials; reduction_factor=2 keeps roughly
    # half of the trials at every rung.
    asha = ASHAScheduler(max_t=num_epochs, grace_period=1, reduction_factor=2)

    cli_reporter = CLIReporter(
        parameter_columns=['model_type', 'hidden_size', 'num_layers', 'learning_rate'],
        metric_columns=['loss', 'training_iteration'],
    )

    trainable = tune.with_parameters(train_with_tune,
                                     num_epochs=num_epochs,
                                     num_gpus=gpus_per_trial)

    analysis = tune.run(
        trainable,
        resources_per_trial={'cpu': 1, 'gpu': gpus_per_trial},
        metric='loss',
        mode='min',
        config=search_space,
        num_samples=num_samples,
        checkpoint_at_end=True,
        scheduler=asha,
        progress_reporter=cli_reporter,
        name='tune_model_asha',
    )

    best_config = analysis.best_config
    print('Best hyperparameters found were: ', best_config)
    return best_config

# Reserve a GPU per trial only when one is available; on a CPU-only machine a
# GPU request means the trials cannot be scheduled.
config = tune_model_asha(num_samples=3,
                         gpus_per_trial=1 if torch.cuda.is_available() else 0)

In [None]:
# Single training run.

# Hyperparameters are set by hand here and deliberately override any `config`
# returned by the tuning cell above.
config = {
    'model_type': 'LSTM',
    'input_size': 1,
    'hidden_size': 10,
    'num_layers': 1,
    'learning_rate': 1e-1
}

model = MutationNet(config)

# The TestDataModule produces float64 numpy data, so the model must use
# double precision as well.
model = model.double()

# Initialise the data (alternatives kept for quick switching).
#data_module = ZebrafishDataModule(batch_size=1)
data_module = TestDataModule(mean1=1, mean2=5, num_samples=100, sequence_length=1000)
#data_module = ModelledZebrafishDataModule(batch_size=32, num_samples=1000, exp_length=60)

# Early-stopping patience must be smaller than max_epochs, otherwise the
# callback can never trigger.
max_epochs = 15
early_stopping = EarlyStopping(monitor="val_loss", patience=5)

# accelerator="auto" uses a GPU when one is available and falls back to the CPU.
trainer = pl.Trainer(max_epochs=max_epochs, callbacks=[early_stopping],
                     accelerator="auto", devices=1)
#trainer = pl.Trainer(fast_dev_run=True)
trainer.fit(model, data_module)

# Test the model
#trainer.test(datamodule=data_module)

# Test the model
#trainer.test(datamodule=data_module)

In [None]:
# CUDA version PyTorch was built against (None for CPU-only builds).
cuda_version = torch.version.cuda
print(cuda_version)

In [None]:
# Feed a single synthetic noise sequence through the trained model and plot
# the model's first output against the input.

# Match the feature dimension the model was built with; a hard-coded width of
# 2 does not fit the input_size=1 configuration used for training.
sequence = np.random.normal(5, 0.25, (1000, config['input_size']))

# Run the simulation and plot the trace.
#simulated_trace, angles = generate_fish_trace1((3.2, 20), 60, 100)

# Batch of one: (1, sequence_length, input_size).
# NOTE(review): assumes MutationNet expects batch-first input — confirm.
x = np.expand_dims(sequence, axis=0).astype(np.float32)
print(x.shape)

# Inference in float32 without building an autograd graph.
# NOTE: model.float() converts the model in place, so later training cells
# need model.double() again.
model.float()
with torch.no_grad():
    a, b = model(torch.from_numpy(x))

print(a)

a = np.squeeze(a.detach().numpy())

fig, ax1 = plt.subplots()
ax1.plot(a, color='red')

ax2 = ax1.twinx()  # second y axis that shares the same x axis
ax2.plot(sequence)

fig.tight_layout()  # otherwise the right y-label is slightly clipped
plt.show()

print(sequence.shape)

In [None]:
# Reinitialise the trainer so that the test logs are stored in a new version.
#trainer = pl.Trainer()
#trainer.test(model, datamodule=TestDataModule(mean1=1, mean2=1.35))

In [None]:
# Compare the bout-angle Beta distributions of the two simulated fish.
from scipy.stats import beta
import matplotlib.pyplot as plt

# Set the font size before any text is created so that axis labels, tick
# labels and the legend pick it up too, not just the title.
plt.rcParams.update({'font.size': 16})

fig, ax = plt.subplots(1, 1)

# Beta samples on [0, 1] are mapped to turn angles in [-180, 180] degrees,
# the same mapping generate_fish_trace applies to np.random.beta draws.
x = np.linspace(0, 1, 100)
turn_angles = (x * 360) - 180
for (a, b), line_style in [((10, 10), 'r-'), ((30, 15), 'b-')]:
    ax.plot(turn_angles, beta.pdf(x, a, b),
            line_style, lw=5, alpha=0.6, label=f'beta({a}, {b})')

plt.xlabel('Angle in degrees')
plt.ylabel('Sample weight')

plt.legend()

plt.title('Bout angle beta distribution')

plt.show()

In [None]:
%matplotlib inline

# Modelling function: variant of generate_fish_trace that also returns the heading.
def generate_fish_trace1(beta_parameters, exp_length, frequency):
    """Simulate a zebrafish trace and return it together with its heading.

    The fish starts at a Gaussian random position (std 3) with a uniformly
    random heading. About once per second it bouts: the velocity jumps to a
    Gaussian peak around 0.5 (std 30 %), and the heading turns by
    ``Beta(a, b) * 360 - 180`` degrees. Between bouts the velocity decays
    linearly to 0 within about one second.

    Args:
        beta_parameters: ``(a, b)`` of the bout-angle Beta distribution.
        exp_length: Duration of the simulated experiment in seconds.
        frequency: Number of samples per second.

    Returns:
        simulated_trace: Array of shape (exp_length * frequency, 2) with x in
            column 0 and y in column 1.
        angle: Array of shape (exp_length * frequency, 1) with the absolute
            heading in degrees at every sample.
    """
    a, b = beta_parameters
    num_steps = exp_length * frequency

    bout_freq = 1  # bouts per second
    avg_peak_velocity = 0.5  # mean velocity right after a bout, per sample

    def next_bout_countdown():
        """Samples until the next bout: Gaussian around one period, clamped at 0."""
        countdown = int(np.random.normal(
            loc=bout_freq * frequency,
            scale=(frequency / 2) / bout_freq))
        return max(countdown, 0)

    simulated_trace = np.empty((num_steps, 2))
    angle = np.empty((num_steps, 1))
    x = simulated_trace[:, 0]
    y = simulated_trace[:, 1]

    x[0], y[0] = np.random.normal(0, 3), np.random.normal(0, 3)
    bout_countdown = next_bout_countdown()
    current_velocity = 0
    # Bout angles are relative, so the absolute heading (degrees) is tracked.
    current_orientation = np.random.random_sample() * 360
    # The first sample's heading is the initial orientation (not the velocity).
    angle[0] = current_orientation

    for i in range(1, num_steps):
        if bout_countdown == 0:
            current_velocity = np.random.normal(
                loc=avg_peak_velocity, scale=avg_peak_velocity * 0.3)
            current_orientation = update_orientation(
                current_orientation, (np.random.beta(a, b) * 360) - 180)
            bout_countdown = next_bout_countdown()
        else:
            bout_countdown -= 1
            # Linear decay that reaches 0 after about `frequency` samples.
            current_velocity = max(
                current_velocity - avg_peak_velocity / frequency, 0)

        moved_x, moved_y = pol2cart(current_velocity, current_orientation)
        x[i] = x[i - 1] + moved_x
        y[i] = y[i - 1] + moved_y
        angle[i] = current_orientation

    return simulated_trace, angle
        
# Simulate a fish, extract its relative turn angles and compare them with the
# model output over time.
sim_length_s, sim_frequency = 60, 100
beta_parameters = (10, 10)
#beta_parameters = (20, 5.2)
simulated_trace, angles = generate_fish_trace1(beta_parameters, sim_length_s, sim_frequency)

# Relative turn angle (degrees) at every sample. A turn is registered when the
# fish moves and its rounded heading changes, and it is held until the next
# turn. Differences are wrapped into [-180, 180), so a heading change from 170
# to -170 counts as a +20 degree turn. The first movement only sets the
# reference heading.
extracted_angles = np.zeros((len(simulated_trace), 1))
last_angle = None
relative_angle = 0
for i in range(1, len(simulated_trace)):
    rel_x, rel_y = simulated_trace[i] - simulated_trace[i - 1]
    rho, phi = cart2pol(rel_x, rel_y)
    if rho > 0:
        # Round to whole degrees to suppress numerical noise.
        phi = round(np.rad2deg(phi))
        if last_angle is None:
            last_angle = phi
        else:
            delta = (phi - last_angle + 180) % 360 - 180
            if delta != 0:
                relative_angle = delta
                last_angle = phi
    extracted_angles[i] = relative_angle

# Batch of one: (1, n_samples, 2).
# NOTE(review): the trace has two features (x, y); this only works if the
# model was built with input_size=2 — the single training cell uses 1.
x = np.expand_dims(simulated_trace, axis=0).astype(np.float32)
print(f'Shape of x: {x.shape}')
model.float()
with torch.no_grad():
    output_a, output_b = model(torch.from_numpy(x))

print(output_a.shape)
print(output_b.shape)

# The second model output is plotted as the per-timestep confidence value.
confidence = np.squeeze(output_b.numpy())
time = np.linspace(0, sim_length_s, len(simulated_trace))

# Make the plots readable on the presentation; set before creating any text.
plt.rcParams.update({'font.size': 16})

fig, ax1 = plt.subplots()
color = 'tab:red'
ax1.set_xlabel('Time (s)')
ax1.set_ylabel('Confidence value', color=color)
ax1.plot(time, confidence, color=color)
ax1.tick_params(axis='y', labelcolor=color)

ax2 = ax1.twinx()  # second y axis that shares the same time axis
color = 'tab:blue'
ax2.set_ylabel('Angle in degrees', color=color)
ax2.set_ylim([-180, 180])
ax2.plot(time, extracted_angles, color=color)
ax2.tick_params(axis='y', labelcolor=color)

plt.title('Confidence values for extracted bout angles')

fig.tight_layout()  # otherwise the right y-label is slightly clipped
plt.show()

# Trajectory of the same fish, coloured by time.
segments = np.stack([simulated_trace[:-1], simulated_trace[1:]], axis=1)
fig, axs = plt.subplots()
lc = LineCollection(segments, cmap='viridis')
lc.set_array(np.linspace(0, sim_length_s, len(simulated_trace) - 1))
line = axs.add_collection(lc)
axs.set_xlim(simulated_trace[:, 0].min(), simulated_trace[:, 0].max())
axs.set_ylim(simulated_trace[:, 1].min(), simulated_trace[:, 1].max())
plt.colorbar(line)
plt.title(f'Simulated fish (bout angles ~ Beta{beta_parameters})')
plt.show()

In [None]:
# Inspect the learned parameters of the recurrent and linear layers.
# NOTE(review): assumes the model exposes `lstm` and `linear` submodules —
# confirm for model_type='GRU'.
lstm = model.lstm

for named_parameter in lstm.named_parameters():
    print(named_parameter)

print('break')

linear_parameters = list(model.linear.named_parameters())
for named_parameter in linear_parameters:
    print(named_parameter)

# named_parameters() returns a generator whose repr is not informative, so
# list the parameter names instead.
print([name for name, _ in linear_parameters])

In [None]:
# Try to extract the turn angles (radians) and step lengths from a trace.
simulated_trace, angles = generate_fish_trace1((20, 5.2), 60, 100)
extracted_angles = []
extracted_velocities = []

# Turns are measured relative to an initial reference heading of 0 rad.
last_angle = 0
relative_angle = 0

for previous, current in zip(simulated_trace[:-1], simulated_trace[1:]):
    rel_x, rel_y = current - previous
    rho, phi = cart2pol(rel_x, rel_y)
    # A stationary fish has no heading (arctan2(0, 0) == 0), so only update
    # while it moves; otherwise every stop would look like a turn to 0 rad.
    if rho > 0 and phi != last_angle:
        # Wrap into [-pi, pi) so crossing the +-180 degree boundary is a
        # small turn rather than a ~360 degree jump.
        relative_angle = (phi - last_angle + np.pi) % (2 * np.pi) - np.pi
        last_angle = phi
    extracted_angles.append(relative_angle)
    extracted_velocities.append(rho)

plt.figure()
plt.plot(np.rad2deg(extracted_angles))
plt.show()

plt.figure()
plt.plot(extracted_velocities)
plt.show()

extracted_angles2 = extracted_angles

In [None]:
# Compare the angles extracted from two runs. `extracted_angles1` has to be
# stored by hand from an earlier run (e.g. `extracted_angles1 = extracted_angles`)
# before executing this cell; `extracted_angles2` is set by the cell above.
if 'extracted_angles1' not in globals():
    print('extracted_angles1 is not defined: store the angles of a first run '
          'with `extracted_angles1 = extracted_angles` and re-run this cell.')
else:
    plt.figure()
    plt.plot(extracted_angles1)
    plt.plot(extracted_angles2, color='red')
    plt.show()